I am trying to build a custom loss for a regression problem, following this answer: Keras Custom loss function to pass arguments other than y_true and y_pred
My function currently looks like this:
def CustomLoss(model, X_valid, y_valid, batch_size):
    def Loss(y_true, y_pred):
        n_samples = 5
        mc_predictions = np.zeros((n_samples, 256, 256))
        for i in range(n_samples):
            y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
            # (Other operations...)
        return LossValue
    return Loss
When execution reaches this line:
    y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
I get the following error:
Method requires being in cross-replica context, use get_replica_context().merge_call()
From what I have gathered, model.predict cannot be called inside a loss function. Is there a workaround or an alternative approach? Please let me know if my question is unclear or if you need any additional information. Thanks!
It sounds like you can use model.add_loss for this. It lets you register the loss inside the model itself, so the loss is no longer limited to a function of y_true and y_pred: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss
Some pseudo-code:
class YourModel(tf.keras.Model):
    ...
    def call(self, inputs):
        unpack, any, extra, stuff = inputs
        # (your network code goes here)
        loss = ...  # (other operations)
        self.add_loss(loss)
        return output
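A minimal runnable version of the pseudo-code above, assuming TensorFlow 2.x. The single Dense layer, the two-tensor input (x, extra) and the squared-difference term are hypothetical placeholders for "your network code" and "other operations"; the sketch only shows that a loss computed inside call() from arbitrary tensors can be registered with add_loss, after which compile() needs no loss argument.

```python
import numpy as np
import tensorflow as tf


class YourModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Hypothetical placeholder network: a single dense layer
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        # Unpack the features plus any extra tensor the loss needs
        x, extra = inputs
        output = self.dense(x)
        # Hypothetical "other operations": mean squared difference to `extra`
        loss = tf.reduce_mean(tf.square(output - extra))
        self.add_loss(loss)
        return output


x = np.random.rand(32, 4).astype("float32")
extra = np.random.rand(32, 1).astype("float32")

model = YourModel()
# No `loss=` argument: the only loss is the one registered via add_loss
model.compile(optimizer="adam")
history = model.fit((x, extra), epochs=1, batch_size=8, verbose=0)
```

Because the extra tensors travel through `inputs`, fit() is called without a separate y; anything the loss needs just has to be passed in as part of the model input.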
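(In case you don't know, model.predict is essentially model.call run batch by batch, with extra machinery such as batching, callbacks and conversion to NumPy arrays attached. Inside a loss or inside call(), call the model directly, e.g. model(X_valid), rather than going through model.predict.)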
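To illustrate the point above with the loss closure from the question: the model can be called directly inside the loss, which returns tensors that take part in the training step, whereas model.predict runs its own prediction loop. This sketch assumes TF 2.x and guesses, from n_samples and mc_predictions in the question, that the goal is Monte Carlo dropout-style sampling; the toy model, make_mc_loss and the variance penalty are hypothetical stand-ins for the elided "other operations".

```python
import numpy as np
import tensorflow as tf


def make_mc_loss(model, x_extra, n_samples=5):
    """Hypothetical loss factory: MSE plus the variance of n_samples
    stochastic forward passes of `model` on the captured data x_extra."""
    x_extra = tf.convert_to_tensor(x_extra)

    def loss(y_true, y_pred):
        # Call the model directly (not model.predict): this returns tensors
        # and can run inside the compiled training step.
        samples = tf.stack(
            [model(x_extra, training=True) for _ in range(n_samples)]
        )
        # Spread of the sampled predictions across the n_samples axis
        mc_variance = tf.reduce_mean(tf.math.reduce_variance(samples, axis=0))
        mse = tf.reduce_mean(tf.square(y_true - y_pred))
        return mse + mc_variance

    return loss


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dropout(0.5),  # stays active because training=True above
    tf.keras.layers.Dense(1),
])

X_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
X_valid = np.random.rand(16, 4).astype("float32")

model.compile(optimizer="adam", loss=make_mc_loss(model, X_valid))
history = model.fit(X_train, y_train, epochs=1, batch_size=16, verbose=0)
```

Note that this runs n_samples forward passes over all of the captured data on every training step, so with a large X_valid it may be worth capturing a smaller subset.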