Reputation: 11
I'm currently training a network for multi-image super resolution based on the RAMS network, and I'm using a different, pre-trained model as my metric. In my training function I'm trying to calculate the gradients:
def train_step(self, lr, hr, mask):
    lr = tf.cast(lr, tf.float32)
    with tf.GradientTape() as tape:
        # forward pass through the model being trained
        sr = self.checkpoint.model(lr, training=True)
        # the loss internally uses the pre-trained metric model
        loss = self.loss(hr, sr, mask, self.image_hr_size)
        print(loss.shape)
    gradients = tape.gradient(
        loss, self.checkpoint.model.trainable_variables)
    self.checkpoint.optimizer.apply_gradients(
        zip(gradients, self.checkpoint.model.trainable_variables))
Inside the loss() function I use model.predict() on batches of my data (sketched below), and I get an error:
LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)
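The loss is organized roughly like this (a simplified sketch, not my exact code; metric_model is just a placeholder name for the pre-trained model I use as a metric):

def loss(self, hr, sr, mask, hr_size):
    # simplified sketch: the pre-trained model is applied to both images
    # through predict(), which returns NumPy arrays and builds its own
    # input pipeline (hence the IteratorGetNext op in the error)
    hr_features = self.metric_model.predict(hr)
    sr_features = self.metric_model.predict(sr)
    return tf.reduce_mean(tf.abs(hr_features - sr_features))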
In similar threads I found the suggestion to replace model.predict() with model() - the code runs, but it is very slow (about 6 hours per epoch). Is there a way I can use model.predict() while calculating the gradients, or speed the process up in any way?
Upvotes: 1
Views: 373
Reputation: 12939
predict() is for inference time; it returns NumPy arrays, so there is nothing for the GradientTape to differentiate through. Just do it the way you are already doing it in the train loop, i.e. loss(self.checkpoint.model(whatever, training=True), targets).
As for the slowness, that does not depend on the train loop; it depends on your PC, the batch size, the model size, the data type and so on.
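Inside your loss this would look roughly like the sketch below (metric_model is a placeholder for your pre-trained metric network; training=False is assumed because that network stays frozen):

def loss(self, hr, sr, mask, hr_size):
    # call the pre-trained metric network directly instead of predict(),
    # so its ops are recorded by the GradientTape and gradients can flow
    hr_features = self.metric_model(hr, training=False)
    sr_features = self.metric_model(sr, training=False)
    return tf.reduce_mean(tf.abs(hr_features - sr_features))

The key point is that the metric network's output stays a tf.Tensor instead of a NumPy array.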
Upvotes: 2