Reputation: 337
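Suppose I have a model like this:
model = Model(inputs=[A, B], outputs=C)
with a custom loss:
from keras import backend as K

def actor_loss(y_true, y_pred):
    log_lik = y_true * K.log(y_pred)
    # B is the second model input; stop_gradient treats it as a constant
    loss = -K.sum(log_lik * K.stop_gradient(B))
    return loss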
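To see what K.gradients will actually differentiate here, below is a minimal NumPy mirror of actor_loss (a sketch, not Keras code). The names actor_loss_np and actor_loss_grad_wrt_pred are hypothetical, and B is assumed to hold one value per sample with shape (batch, 1). Because K.stop_gradient(B) makes B a constant, the gradient with respect to y_pred reduces to -y_true * B / y_pred, which a finite-difference check agrees with:

```python
import numpy as np

def actor_loss_np(y_true, y_pred, b):
    # NumPy mirror of actor_loss: -sum(y_true * log(y_pred) * B)
    return -np.sum(y_true * np.log(y_pred) * b)

def actor_loss_grad_wrt_pred(y_true, y_pred, b):
    # Analytic d(loss)/d(y_pred), treating b as a constant (the stop_gradient effect)
    return -y_true * b / y_pred

# Hypothetical batch: 2 samples, 3 actions; b has shape (batch, 1)
y_true = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3]])
b = np.array([[2.0], [-1.0]])

grad = actor_loss_grad_wrt_pred(y_true, y_pred, b)

# Central finite difference on entry [0, 0] as a sanity check
eps = 1e-6
up, down = y_pred.copy(), y_pred.copy()
up[0, 0] += eps
down[0, 0] -= eps
numeric = (actor_loss_np(y_true, up, b) - actor_loss_np(y_true, down, b)) / (2 * eps)
print(grad[0, 0], numeric)  # both approximately -4.0
```

Entries where y_true is zero get zero gradient, so only the probability of the taken action is pushed, scaled by B.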
Now I'm trying to define a function that returns the gradients of the loss with respect to the weights for a given pair of inputs and target outputs, and expose it as self.get_grads. Here is roughly what I mean, in pseudocode:
def _get_grads(inputs, targets):
    loss = model.loss(targets, model.output)
    weights = model.trainable_weights
    grads = K.gradients(loss, weights)
    # model.input[0] (aka 'A') <---- inputs[0]
    # model.input[1] (aka 'B') <---- inputs[1]
    return K.function(model.input, grads)

self.get_grads = _get_grads
My question is: how do I feed the inputs argument into the graph inside that function? (So far I've only worked with .fit, not with K.gradients, and I can't find any decent documentation covering a custom loss or multiple inputs.)
Upvotes: 1
Views: 291
Reputation: 337
My understanding of K.function, K.gradients and the custom loss was fundamentally wrong. You use the function to construct a mini-graph that computes the gradients of the loss with respect to the weights; the function itself doesn't need any arguments.
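from keras import backend as K
from keras.layers import Input

def _get_grads():
    targets = Input(shape=...)  # placeholder for the labels (y_true)
    loss = model.loss(targets, model.output)
    weights = model.trainable_weights
    grads = K.gradients(loss, weights)
    # placeholders in order: the model inputs (A, B), then the targets
    return K.function(model.input + [targets], grads)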
I was under the impression that _get_grads was itself the K.function, but that was wrong: _get_grads() returns a K.function. You then use it like this:
f = _get_grads()  # constructs the mini-graph that gives the gradients
grads = f(inputs + [labels])  # inputs is a list: [A_values, B_values]
The arrays in inputs are fed to model.input, labels is fed to targets, and f returns the list of gradients.
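The "build once, then call with a list of arrays" pattern can be mimicked without Keras. The sketch below is a hypothetical NumPy stand-in, not what K.function does internally: it assumes a toy model whose output is C = softmax(A @ W) (B only enters through the loss), and softmax, loss_np and get_grads_np are made-up names. What it illustrates is the calling convention: the list passed to f follows the placeholder order model.input + [targets], i.e. [A, B, labels], and the result is a list with one gradient per trainable weight.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss_np(W, arrays):
    # actor_loss applied to the hypothetical toy model C = softmax(A @ W)
    a, b, y_true = arrays
    return -np.sum(y_true * np.log(softmax(a @ W)) * b)

def get_grads_np(W):
    # Build once; the returned f takes arrays ordered like model.input + [targets]
    def f(arrays):
        a, b, y_true = arrays
        p = softmax(a @ W)
        # d(loss)/dz for z = a @ W, with b held constant (stop_gradient)
        dz = -b * (y_true - p * y_true.sum(axis=1, keepdims=True))
        return [a.T @ dz]  # one gradient per trainable weight, as a list
    return f

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
f = get_grads_np(W)

inputs = [rng.normal(size=(5, 4)), rng.normal(size=(5, 1))]  # [A, B] values
labels = np.eye(3)[rng.integers(0, 3, size=5)]  # one-hot targets
grads = f(inputs + [labels])
print(grads[0].shape)  # (4, 3)
```

The hand-derived gradient can be checked against finite differences of loss_np; in real Keras code, K.gradients performs this differentiation for you.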
Upvotes: 0
Reputation: 56367
If you call K.function, you get an actual callable function, so you just call it with some values. The format is the same as for model.fit; in your case it should be a list of two arrays, each including the batch dimension:
self.get_grads = _get_grads(inputs, targets)
grad_value = self.get_grads([input1, input2])
where input1 and input2 are NumPy arrays that include the batch dimension.
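A minimal NumPy sketch of preparing input1 and input2 with a leading batch dimension; the sample shapes (4 features for A, one scalar for B) are assumptions chosen for illustration:

```python
import numpy as np

# Hypothetical single samples: A has 4 features, B is one scalar per sample
sample_a = np.array([0.1, -0.2, 0.3, 0.0])
sample_b = np.array([1.5])

# Add a leading batch axis of size 1 to each input
input1 = np.expand_dims(sample_a, axis=0)  # shape (1, 4)
input2 = np.expand_dims(sample_b, axis=0)  # shape (1, 1)

# For several samples, stack them along a new axis 0 instead
batch_a = np.stack([sample_a, 2 * sample_a])  # shape (2, 4)

print(input1.shape, input2.shape, batch_a.shape)  # (1, 4) (1, 1) (2, 4)
```

These arrays are then passed as one list in placeholder order, e.g. self.get_grads([input1, input2]).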
Upvotes: 2