zephyr3319

Reputation: 11

Is there a way to use model.predict() while calculating gradient?

I'm currently training a network for multi-image super-resolution based on the RAMS network, and I'm using a different, pre-trained model as my metric. In my training function I'm trying to calculate the gradient:

def train_step(self, lr, hr, mask):
    lr = tf.cast(lr, tf.float32)

    # Record the forward pass and the loss on the tape
    with tf.GradientTape() as tape:
        sr = self.checkpoint.model(lr, training=True)
        loss = self.loss(hr, sr, mask, self.image_hr_size)
        print(loss.shape)  # debugging output

    # Differentiate the loss w.r.t. the model weights and apply the update
    gradients = tape.gradient(
        loss, self.checkpoint.model.trainable_variables)
    self.checkpoint.optimizer.apply_gradients(
        zip(gradients, self.checkpoint.model.trainable_variables))

Inside the loss() function I use model.predict() on batches of my data, and I get this error:

LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)

In similar threads I found the suggestion to replace model.predict() with model() - the code then runs, but it is very slow (6 h per epoch). Is there a way to use model.predict() while calculating the gradient, or to speed the process up in any way?

Upvotes: 1

Views: 373

Answers (1)

Alberto

Reputation: 12939

predict() is for inference only: it feeds the input through a tf.data pipeline outside the tape (which is where the IteratorGetNext error comes from), so no gradients can flow through it. Just do it the way you are already doing it in the train loop, by calling the model directly: loss(self.checkpoint.model(whatever, training=True), targets)
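For example, a minimal sketch of what the loss could look like with a direct call (self.metric_model and the reduction here are placeholders for whatever your metric network and actual loss are):

def loss(self, hr, sr, mask, hr_size):
    # Direct call instead of predict(): these ops are recorded on the
    # GradientTape; training=False keeps BatchNorm/Dropout in inference mode
    metric_hr = self.metric_model(hr, training=False)
    metric_sr = self.metric_model(sr, training=False)
    # placeholder reduction - substitute your actual metric-based loss
    return tf.reduce_mean(tf.abs(metric_hr - metric_sr) * mask)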

As for the slowness: that does not depend on the train loop itself, but on your PC, the batch size, the model size, the data type, and so on
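One thing that often helps with a custom loop like this, though, is compiling the step with tf.function so it runs as a graph instead of eagerly. A rough sketch applied to the question's train_step (assuming the same class and attributes):

@tf.function  # compiles the step into a graph on first call
def train_step(self, lr, hr, mask):
    lr = tf.cast(lr, tf.float32)
    with tf.GradientTape() as tape:
        sr = self.checkpoint.model(lr, training=True)
        loss = self.loss(hr, sr, mask, self.image_hr_size)
        # note: a Python print() here would only run at trace time;
        # use tf.print() if you need per-step output
    gradients = tape.gradient(
        loss, self.checkpoint.model.trainable_variables)
    self.checkpoint.optimizer.apply_gradients(
        zip(gradients, self.checkpoint.model.trainable_variables))
    return loss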

Upvotes: 2
