Ethan Yanjia Li

Reputation: 766

Try to understand AutoGraph and tf.function: print loss in tf.function

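import tensorflow as tf

# `mse` is not defined in the original snippet; the tensor names in the output
# ("mean_absolute_error/...") suggest a Keras MeanAbsoluteError loss object.
mse = tf.keras.losses.MeanAbsoluteError()

def train_one_step():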
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)

    tf.print('inner tf print', loss)
    print("inner py print", loss)

    return loss


@tf.function
def train():
    loss = train_one_step()

    tf.print('outer tf print', loss)
    print('outer py print', loss)

    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)

I'm trying to understand tf.function better. I printed the loss at three levels (inside the helper, inside the tf.function, and outside it), each with both tf.print and Python print, and got this output:

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)

  1. What's the difference between tf.print and Python print?
  2. It looks like Python print is executed during graph construction (tracing), but tf.print only when the graph is executed?
  3. Does the above only apply when there's a tf.function decorator? Outside of it, does tf.print run before Python print?

Upvotes: 5

Views: 3399

Answers (2)

nessuno

Reputation: 27050

I covered and answered all your questions in a three-part article: "Analyzing tf.function to discover AutoGraph strengths and subtleties": part 1, part 2, part 3.

To summarize and answer your 3 questions:

  • What's the difference between tf.print and Python print?

tf.print is a TensorFlow construct that prints to standard error by default and, more importantly, produces an operation when evaluated.

When an operation is run, even in eager execution, it produces a "node" in more or less the same way as in TensorFlow 1.x.

tf.function is able to capture the operation generated by tf.print and convert it into a graph node.

By contrast, print is a Python construct that prints to standard output by default and does not generate any operation when it is executed. Therefore, tf.function cannot convert it into a graph equivalent, and it is executed only during function tracing.
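A minimal sketch of this difference, assuming TensorFlow 2.x. The names `step`, `trace_log`, and `calls` are hypothetical illustrations: the Python list only grows while the function is traced, whereas the `tf.Variable` is updated by a graph op, just like tf.print, so it changes on every call.

```python
import tensorflow as tf

# Hypothetical counters for illustration only.
trace_log = []          # Python side effect: appended to only while tracing
calls = tf.Variable(0)  # updated by a graph op: changes on every call

@tf.function
def step(x):
    trace_log.append("traced")             # Python code: runs during tracing only
    print("py print during tracing:", x)   # x is a symbolic Tensor here
    tf.print("tf.print at run time:", x)   # becomes a node in the traced graph
    calls.assign_add(1)                    # graph op: runs each time the graph runs
    return x * 2

for _ in range(3):
    step(tf.constant(1.0))  # same dtype/shape each time, so only one trace

print(len(trace_log), int(calls.numpy()))  # 1 3
```

The Python print line appears once, while the tf.print line appears three times, once per call.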

  • It looks like Python print is executed during graph construction (tracing), but tf.print only when the graph is executed?

I answered this question in the previous point, but once again: print is executed only during function tracing, while tf.print is called during tracing (where it adds a print operation to the graph instead of printing) and actually prints every time its graph representation is executed (after tf.function has successfully converted the function to a graph).
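Note that "only during tracing" does not mean "only once": tf.function retraces when it sees a new input signature, and the Python print runs again for each new trace. A minimal sketch, assuming TensorFlow 2.x with default tf.function settings (no `reduce_retracing`); `f`, `trace_events`, and `runs` are hypothetical names:

```python
import tensorflow as tf

trace_events = []      # hypothetical: grows once per trace
runs = tf.Variable(0)  # hypothetical: grows once per call (graph op)

@tf.function
def f(x):
    trace_events.append("trace")   # Python side effect: tracing only
    print("tracing with", x)
    tf.print("running with", x)
    runs.assign_add(1)
    return x + 1

f(tf.constant(1.0))         # new signature: trace, then run
f(tf.constant(2.0))         # same dtype and shape: run only
f(tf.constant([1.0, 2.0]))  # new shape: retrace, then run
f(3)                        # Python int: its value is part of the trace key, so retrace
f(3)                        # same Python value: reuses the existing trace
```

Here the Python print fires three times (one per trace) and tf.print five times (one per call).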

  • Does the above only apply when there's a tf.function decorator? Outside of it, does tf.print run before Python print?

Yes. tf.print does not run before or after print: in eager execution, both are evaluated as soon as the Python interpreter reaches the statement. The only difference in eager execution is the output stream.
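A small sketch of the eager case, assuming TensorFlow 2.x. Both statements run immediately in source order; the only difference is where the text goes. tf.print accepts an `output_stream` argument, so it can be pointed at standard output to keep both outputs on the same stream, which can make interleaved output easier to read:

```python
import sys
import tensorflow as tf

loss = tf.constant(1.5)  # stand-in value for the loss

# Eager execution (no tf.function): each line runs as soon as it is reached.
tf.print("tf.print, default stream (stderr):", loss)
print("print, default stream (stdout):", loss)

# Send tf.print to stdout so both outputs share one stream.
tf.print("tf.print on stdout:", loss, output_stream=sys.stdout)
```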

At any rate, I suggest you read the three linked articles, since they cover this and other peculiarities of tf.function in detail.

Upvotes: 5

BlueSun

Reputation: 3570

print is the normal Python print. tf.print is part of the TensorFlow graph. In eager mode, TensorFlow executes each operation immediately. That is why, outside your @tf.function function, the Python print outputs a number (TensorFlow has already computed the value and hands it to the normal print function), and that is also why tf.print prints immediately.
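To see the difference in what Python print receives, here is a minimal sketch assuming TensorFlow 2.x; `g` and `seen` are hypothetical names. In eager mode the tensor already holds a value, while inside tf.function the Python code only sees a symbolic placeholder during tracing:

```python
import tensorflow as tf

seen = {}  # hypothetical: stores what the Python code saw while tracing

@tf.function
def g(x):
    seen["inside"] = x  # symbolic graph tensor: no value exists yet
    return x * 2

x = tf.constant(3.0)
print(x)          # tf.Tensor(3.0, shape=(), dtype=float32)
print(x.numpy())  # 3.0
y = g(x)
print(seen["inside"])  # a symbolic Tensor("...") with a name, shape and dtype only
print(y.numpy())  # 6.0
```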

On the other hand, inside the @tf.function function, TensorFlow does not execute the operations immediately. Instead, it "stacks" the TensorFlow operations that you call into a larger graph, which is executed all at once after tracing has finished.

That is why the Python print does not give you the number inside the @tf.function function (the graph has not been executed at that point). Once tracing is over, the graph is executed, including the tf.print operations in it. Therefore tf.print prints after the Python print and shows the actual loss values.
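One way to check this is to inspect the traced graph. A sketch assuming TensorFlow 2.x, where tf.print lowers to a `PrintV2` op; `inner` and `outer` are hypothetical names mirroring `train_one_step` and `train` from the question. The ops from the plain Python helper end up in the same graph as the outer function, and nothing is recorded for the Python print calls:

```python
import tensorflow as tf

def inner(x):  # plain Python helper, like train_one_step in the question
    tf.print("inner tf print", x)
    print("inner py print", x)  # runs once, during tracing
    return x * 2

@tf.function
def outer(x):
    y = inner(x)  # traced into the same graph as outer
    tf.print("outer tf print", y)
    print("outer py print", y)
    return y

# Tracing happens here; the two Python prints fire now.
concrete = outer.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types.count("PrintV2"))  # 2

# Executing the graph triggers both tf.print ops, but no Python print.
result = concrete(tf.constant(1.5))
```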

Upvotes: 2
