Reputation: 766
import tensorflow as tf

# `mse` is not defined in the snippet; a Keras loss is assumed here
# (the trace below suggests something like MeanAbsoluteError):
mse = tf.keras.losses.MeanAbsoluteError()

def train_one_step():
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)
        tf.print('inner tf print', loss)
        print("inner py print", loss)
        return loss

@tf.function
def train():
    loss = train_one_step()
    tf.print('outer tf print', loss)
    print('outer py print', loss)
    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)
I'm trying to understand tf.function better, so I printed the loss in four places with two different methods. It produces output like this:
inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
Upvotes: 5
Views: 3399
Reputation: 27050
I covered and answered all your questions in a three-part article: "Analyzing tf.function to discover AutoGraph strengths and subtleties": part 1, part 2, part 3.
To summarize and answer your 3 questions:
tf.print is a TensorFlow construct that prints to standard error by default and, more importantly, produces an operation when evaluated. When an operation runs, even in eager execution, it produces a "node" in more or less the same way as TensorFlow 1.x. tf.function is able to capture the operation generated by tf.print and convert it to a graph node.
On the contrary, print is a Python construct that prints to standard output by default and does not generate any operation when executed. Therefore, tf.function is not able to convert it to a graph equivalent, and it executes only during function tracing.
I answered this in the previous point, but once again: print is executed only during function tracing, while tf.print is executed both during tracing and when its graph representation is executed (after tf.function has successfully converted the function to a graph).
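The trace-time vs. run-time distinction can be sketched in a few lines (a minimal illustration; `trace_log` is a hypothetical helper introduced here just to observe when the Python body runs):

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: records when the Python body runs

@tf.function
def f(x):
    trace_log.append("traced")   # plain Python: runs only while tracing
    tf.print("running with", x)  # graph op: runs on every call
    return x + 1

f(tf.constant(1))
f(tf.constant(2))  # same input signature, so the function is not retraced
# trace_log now holds a single entry: the Python code ran exactly once
```

Calling `f` with a tensor of a new dtype or shape would trigger a retrace, and the Python side effect would run again.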
Yes. tf.print does not run before or after print. In eager execution, both are evaluated as soon as the Python interpreter reaches the statement. The only difference in eager execution is the output stream.
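A minimal eager-mode sketch of the stream difference (the `output_stream` argument is part of tf.print's public API):

```python
import sys
import tensorflow as tf

t = tf.constant(2.5)
print("py print:", t)     # evaluated immediately, goes to standard output
tf.print("tf print:", t)  # evaluated immediately too, but to standard error
tf.print("tf print:", t, output_stream=sys.stdout)  # the stream is configurable
```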
At any rate, I suggest reading the three linked articles, since they cover this and other peculiarities of tf.function in detail.
Upvotes: 5
Reputation: 3570
print is the normal Python print. tf.print is part of the TensorFlow graph.
In eager mode, TensorFlow executes the graph directly. That is why, outside of your @tf.function-decorated function, the output of the Python print is a number (TensorFlow executes the graph directly and hands the number to the normal print function), and it is also why tf.print prints immediately there.
Inside the @tf.function-decorated function, on the other hand, TensorFlow does not execute the graph immediately. Instead, it "stacks" the TensorFlow functions you call into a larger graph, which is executed all at once at the end of @tf.function.
That is why the Python print does not give you the number inside the @tf.function-decorated function (the graph has not yet been executed at that point). But after the function returns, the graph is executed, together with the tf.print in the graph. The tf.print therefore prints after the Python print and gives you the actual loss numbers.
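The "build first, execute later" behaviour can be made visible by inspecting the graph that tf.function builds before running it (a small sketch; the exact op-type names such as 'Mul' are TensorFlow internals and may vary between versions):

```python
import tensorflow as tf

@tf.function
def h(x):
    return x * 2.0 + 1.0

# tf.function first builds a graph; we can inspect it before executing it:
cf = h.get_concrete_function(tf.TensorSpec([], tf.float32))
op_types = [op.type for op in cf.graph.get_operations()]
# op_types lists the stacked nodes (e.g. a multiply and an add) that all
# run together only when the function is actually called:
result = h(tf.constant(3.0))  # 3 * 2 + 1
```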
Upvotes: 2