arash javanmard

TensorFlow 2: Getting a Tensor's Value

I am switching to TF2 and just followed this tutorial, in which the train and test step functions are now decorated with @tf.function.

How can I print the values of the prediction tensor (y_pred) and of the loss tensor?

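import tensorflow as tf

# model, loss_object, optimizer, train_loss and train_accuracy are defined earlier, as in the tutorial
@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)

    # the predictions tensor is named `predictions` here; `y_pred` is not defined
    print("train preds: ", predictions)
    print("train loss: ", loss)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)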

Answers (1)

Sergei Lebedev

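Python's print executes in the Python world, not in the graph. Inside a tf.function it therefore runs only while tf.function is tracing your function to construct a graph (typically once per new input signature), and at that point the tensors are symbolic, so it shows their shape and dtype rather than their values. If you want to print in-graph, on every call, use tf.print: it is added to the graph as an op and outputs the actual tensor values each time the function executes.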
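A minimal sketch of the difference, assuming TensorFlow 2.x with its bundled Keras. The function double, the traces list, and the tiny Dense model, MSE loss and SGD optimizer are hypothetical stand-ins for illustration, not the tutorial's actual objects. The first part shows print firing only during tracing while tf.print fires on every call; the second applies tf.print to predictions and loss in a train_step like the one in the question.

```python
import tensorflow as tf

# Hypothetical: records each time the Python body of `double` runs, i.e. each trace.
traces = []

@tf.function
def double(x):
    traces.append(1)                  # Python side effect: happens only while tracing
    y = x * 2
    print("python print:", y)         # shows a symbolic Tensor, only at trace time
    tf.print("tf.print:", y)          # graph op: prints the concrete value on every call
    return y

for v in [1.0, 2.0, 3.0]:
    double(tf.constant(v))            # same dtype and shape, so the traced graph is reused

# Hypothetical stand-ins for the tutorial's model, loss object and optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    tf.print("train preds:", predictions)  # printed on every call, not only while tracing
    tf.print("train loss:", loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

images = tf.random.normal([4, 3])
labels = tf.random.normal([4, 1])
loss = train_step(images, labels)
```

Running the loop should produce the "python print" line only once (during the single trace) and a "tf.print" line for each of the three calls. Note that tf.print writes to stderr by default; pass output_stream=sys.stdout if you want it on stdout.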
