Eduardo

Reputation: 77

How to call model.predict inside a loss function? (TensorFlow, Keras)

I am trying to build a custom loss for a regression problem, structured as suggested in this answer: Keras Custom loss function to pass arguments other than y_true and y_pred

My function currently looks like this:

import numpy as np

def CustomLoss(model, X_valid, y_valid, batch_size):
    def Loss(y_true, y_pred):
        n_samples = 5
        mc_predictions = np.zeros((n_samples, 256, 256))
        for i in range(n_samples):
            # This is the call that raises the error
            y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
        # (Other operations...)
        return LossValue
    return Loss

When the line y_p = model.predict(X_valid, verbose=1, batch_size=batch_size) is executed, I get the following error:

Method requires being in cross-replica context, use get_replica_context().merge_call()

From what I have gathered, I cannot use model.predict inside a loss function. Is there a workaround or solution for this? Please let me know if my question is unclear or if you need any additional information. Thanks!

Upvotes: 2

Views: 1683

Answers (1)

Taw

Reputation: 429

It sounds like you can use model.add_loss for this. It lets you define the loss inside the model itself, so the loss is no longer restricted to taking only y_true and y_pred. See https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss. Some pseudo-code:

class YourModel(tf.keras.Model):
    ...
    def call(self, inputs):
        # Unpack whatever extra tensors the loss needs
        unpack, any, extra, stuff = inputs
        # (your network code goes here, producing `output`)
        loss = ...  # (other operations)
        self.add_loss(loss)
        return output

(In case you don't know, model.predict is essentially model.call with extras on top, such as batching the input and returning NumPy arrays.)
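To make the pseudo-code above more concrete, here is a minimal sketch of the add_loss pattern. Everything specific in it is an assumption made for illustration, not something taken from the question: the class name MCDropoutRegressor, the layer sizes, the penalty_weight parameter, and the choice of "variance across repeated dropout passes" as a stand-in for the unspecified "(other operations)". The repeated forward passes happen inside call, so model.predict is never needed.

```python
import numpy as np
import tensorflow as tf


class MCDropoutRegressor(tf.keras.Model):
    """Hypothetical model that registers an extra loss term from inside call()."""

    def __init__(self, n_samples=5, penalty_weight=0.1):
        super().__init__()
        self.n_samples = n_samples
        self.penalty_weight = penalty_weight
        self.hidden = tf.keras.layers.Dense(16, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.out = tf.keras.layers.Dense(1)

    def _forward_pass(self, x, training):
        h = self.hidden(x)
        h = self.dropout(h, training=training)
        return self.out(h)

    def call(self, inputs, training=False):
        output = self._forward_pass(inputs, training=training)
        # Stand-in for "(other operations)": run several stochastic passes
        # (dropout forced on) and penalise the variance between them.
        mc = tf.stack(
            [self._forward_pass(inputs, training=True) for _ in range(self.n_samples)]
        )
        variance = tf.reduce_mean(tf.math.reduce_variance(mc, axis=0))
        self.add_loss(self.penalty_weight * variance)
        return output


# Random placeholder data, only to exercise the model.
x = np.random.rand(32, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = MCDropoutRegressor()
optimizer = tf.keras.optimizers.Adam()

# One manual training step: the terms registered via add_loss are
# available in model.losses after the forward pass.
with tf.GradientTape() as tape:
    y_pred = model(x, training=True)
    mse = tf.reduce_mean(tf.square(y_pred - y))
    total_loss = mse + tf.add_n(model.losses)

grads = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

The manual step is only there to show where the extra term ends up. When training with model.fit, Keras adds the add_loss terms to the compiled loss for you.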

Upvotes: 1
