Reputation: 53
I build my model with tensorflow==1.14 and tf.enable_eager_execution(), like this:
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

class Model:
    def __init__(self):
        self.embedding = tf.keras.layers.Embedding(10, 15)
        self.dense = tf.keras.layers.Dense(10)

    @tf.function
    def inference(self, inp):
        print('call function: inference')
        inp_em = self.embedding(inp)
        inp_enc = self.dense(inp_em)
        return inp_enc

    @tf.function
    def fun(self, inp):
        print('call function: fun')
        return self.inference(inp)

model = Model()
When I run the following code for the first time:
a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))
the output is
call function: fun
call function: inference
call function: inference
====================
call function: inference
It seems like TensorFlow builds three graphs for the inference function. How can I build just one graph for it? I would also like to know how tf.function works when two decorated functions call each other. Is this the right way to build my model?
Upvotes: 4
Views: 1561
Reputation: 14495
Sometimes the way tf.function executes can cause a little confusion, particularly when we mix in vanilla Python operations such as print().
We should remember that when we decorate a function with tf.function, it's no longer just a Python function. It behaves a little differently in order to enable fast and efficient use in TF. The vast majority of the time, the changes in behaviour are pretty much unnoticeable (except for the increased speed!), but occasionally we encounter a little nuance like this.
The first thing to note is that if we use tf.print() in place of print(), we get the expected output:
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

class Model:
    def __init__(self):
        self.embedding = tf.keras.layers.Embedding(10, 15)
        self.dense = tf.keras.layers.Dense(10)

    @tf.function
    def inference(self, inp):
        tf.print('call function: inference')
        inp_em = self.embedding(inp)
        inp_enc = self.dense(inp_em)
        return inp_enc

    @tf.function
    def fun(self, inp):
        tf.print('call function: fun')
        return self.inference(inp)

model = Model()
a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))
outputs:
call function: fun
call function: inference
====================
call function: inference
If your question is the symptom of a real-world problem, this is probably the fix!
So what's going on?
Well, the first time we call a function decorated with tf.function, TensorFlow builds an execution graph. In order to do so, it "traces" the TensorFlow operations executed by the Python function.
In order to do this tracing, it is possible that TensorFlow will call the decorated function more than once!
This means that Python-only operations (such as print()) can get executed more than once, whereas TF operations such as tf.print() will behave as you would normally expect.
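This trace-time vs run-time split can be sketched in isolation. Here is a minimal example (the double function and the trace_calls list are illustrative names, not from the question) showing that a Python-only side effect runs when the function is traced, while the graph itself is reused for later calls with the same input signature:

```python
import tensorflow as tf

trace_calls = []

@tf.function
def double(x):
    # Python-only side effect: runs when TensorFlow traces the function,
    # not every time the compiled graph executes.
    trace_calls.append(1)
    return x * 2

a = double(tf.constant([1, 2, 3]))  # first call with this signature: traces
b = double(tf.constant([4, 5, 6]))  # same dtype/shape: reuses the existing graph
```

After both calls, trace_calls typically holds a single entry, even though double() produced a result twice.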
A side effect of this nuance is that we ought to be aware of how tf.function-decorated functions handle state, but that is outside the scope of your question. See the original RFC and this GitHub issue for more info.
And I also want to know how tf.function woks when two functions call each other. Is this the right way to build my model?
In general, we need only decorate the "outer" function with tf.function (.fun() in your example), but if you might also call the inner function directly, you are free to decorate that too.
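As a concrete sketch of the "decorate only the outer function" approach (reusing the layer sizes from the question, though any would do):

```python
import numpy as np
import tensorflow as tf

class Model:
    def __init__(self):
        self.embedding = tf.keras.layers.Embedding(10, 15)
        self.dense = tf.keras.layers.Dense(10)

    def inference(self, inp):
        # Undecorated: when called from fun(), this is traced as part of
        # fun()'s graph rather than getting a separate graph of its own.
        return self.dense(self.embedding(inp))

    @tf.function
    def fun(self, inp):
        return self.inference(inp)

model = Model()
out = model.fun(np.array([1, 2, 3]))  # shape (3, 10)
```

Here only fun() owns a graph; inference() still runs inside it, so there is no separate tracing for the inner call. Decorate inference() as well only if you also call it directly.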
Upvotes: 2