Reputation: 1387
I am switching to TF2 and just followed this tutorial, in which the train and test step functions are now defined with "@tf.function".
How can I print the values of the tensors predictions (y_pred) and loss?
```python
import tensorflow as tf

# model, loss_object, optimizer, train_loss and train_accuracy
# are defined as in the tutorial.

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    print("train preds: ", predictions)
    print("train loss: ", loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)
```
Upvotes: 0
Views: 1439
Reputation: 2679
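`print` executes in the Python world (not in the graph), so it only runs while `tf.function` is tracing your function to construct a graph, typically once per new input signature. At that point the tensors are symbolic, so you see their shape and dtype rather than their values. If you want to print in-graph, on every call, use `tf.print`.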
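A minimal sketch of the difference, assuming TensorFlow 2.x. To keep it self-contained it does not use the tutorial's model. The function `step`, its toy "model" `y_pred = 2 * x`, and the `traces` list are hypothetical stand-ins for `train_step`. `output_stream=sys.stdout` is only there because `tf.print` writes to stderr by default.

```python
import sys
import tensorflow as tf

traces = []  # Python-side record; only grows while tf.function traces


@tf.function
def step(x, y):
    """Hypothetical stand-in for train_step with a toy model y_pred = 2 * x."""
    traces.append(1)  # Python side effect: runs at trace time only

    y_pred = 2.0 * x
    loss = tf.reduce_mean(tf.square(y - y_pred))

    # Plain Python print: runs only during tracing and shows a symbolic
    # tensor (name, shape, dtype), not its values.
    print("traced: y_pred =", y_pred)

    # tf.print becomes an op in the graph, so it runs on every call
    # and prints the actual values.
    tf.print("train preds:", y_pred, output_stream=sys.stdout)
    tf.print("train loss:", loss, output_stream=sys.stdout)
    return loss


x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([2.0, 4.0, 7.0])
for _ in range(3):
    loss = step(x, y)  # tf.print fires on all 3 calls, Python print only on the first

# A new input shape triggers a retrace, so the Python print fires again.
step(tf.constant([1.0, 2.0]), tf.constant([2.0, 4.0]))
print("number of traces:", len(traces))
```

In the question's `train_step`, the same idea means replacing the two `print` calls with `tf.print("train preds:", predictions)` and `tf.print("train loss:", loss)`.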
Upvotes: 3