包乔奔

Reputation: 53

How does tf.function work when two functions call each other

I built my model using tensorflow==1.14 with tf.enable_eager_execution(), like this:

class Model:
  def __init__(self):
    self.embedding = tf.keras.layers.Embedding(10, 15)
    self.dense = tf.keras.layers.Dense(10)

  @tf.function
  def inference(self, inp):
    print('call function: inference')
    inp_em = self.embedding(inp)
    inp_enc = self.dense(inp_em)

    return inp_enc

  @tf.function
  def fun(self, inp):
    print('call function: fun')
    return self.inference(inp)

model = Model()

When I run the following code for the first time:

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

the output is

call function: fun
call function: inference
call function: inference
====================
call function: inference

It seems like TensorFlow builds three graphs for the inference function. How can I build just one graph for it? I would also like to know how tf.function works when two functions call each other. Is this the right way to build my model?

Upvotes: 4

Views: 1561

Answers (1)

Stewart_R

Reputation: 14495

Sometimes the way tf.function executes can cause us a little confusion - particularly when we mix in vanilla python operations such as print().

We should remember that when we decorate a function with tf.function it's no longer just a python function. It behaves a little differently in order to enable fast and efficient use in TF. The vast majority of the time, the changes in behaviour are pretty much unnoticeable (except for the increased speed!) but occasionally we can encounter a little nuance like this.

The first thing to note is that if we use tf.print() in place of print() then we get the expected output:

class Model:
  def __init__(self):
    self.embedding = tf.keras.layers.Embedding(10, 15)
    self.dense = tf.keras.layers.Dense(10)

  @tf.function
  def inference(self, inp):
    tf.print('call function: inference')
    inp_em = self.embedding(inp)
    inp_enc = self.dense(inp_em)

    return inp_enc

  @tf.function
  def fun(self, inp):
    tf.print('call function: fun')
    return self.inference(inp)

model = Model()

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

outputs:

call function: fun
call function: inference
====================
call function: inference

If your question is the symptom of a real-world problem, this is probably the fix!

So what's going on?

Well, the first time we call a function decorated with tf.function, TensorFlow builds an execution graph. In order to do so, it "traces" the TensorFlow operations executed by the Python function.

In order to do this tracing, it is possible that TensorFlow will call the decorated function more than once!

This means that Python-only operations (such as print()) could get executed more than once, but TF operations such as tf.print() will behave as you would normally expect.
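We can see this "runs at trace time, not at call time" behaviour directly. Here is a minimal sketch (assuming TF 2.x, where eager execution is the default; double and trace_count are illustrative names, not from your code): a Python side effect inside a tf.function fires while tracing, but a second call with the same input signature reuses the cached graph and never re-runs the Python body.

```python
import tensorflow as tf

trace_count = 0

@tf.function
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * 2

# First call with this dtype/shape triggers a trace (Python body runs).
double(tf.constant([1, 2, 3]))
first = trace_count

# Same signature again: the cached graph is reused, no new trace.
double(tf.constant([4, 5, 6]))
```

After the second call, trace_count is unchanged, even though the graph itself executed again.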

A side effect of this nuance is that we ought to be aware of how tf.function-decorated functions handle state, but this is outside the scope of your question. See the original RFC and this github issue for more info.
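As a taste of that state nuance, here is a small sketch (assuming TF 2.x; make_variable is a hypothetical name): because the body may be traced more than once, tf.function refuses a function that creates a fresh tf.Variable on every call.

```python
import tensorflow as tf

@tf.function
def make_variable():
    # A brand-new tf.Variable would be created on every trace,
    # so tf.function rejects this pattern with an error.
    return tf.Variable(1.0) + 1.0

failed = False
try:
    make_variable()
except Exception:
    failed = True  # creation-per-call is disallowed inside tf.function
```

The usual fix is to create the variable once, outside the decorated function, and let the function capture it.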

And I also want to know how tf.function works when two functions call each other. Is this the right way to build my model?

In general, we need only decorate the "outer" function with tf.function (.fun() in your example), but if you might also call the inner function directly, then you are free to decorate it too.
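Concretely, a version of your model with only the outer method decorated might look like this (a sketch under TF 2.x; the inner inference is then just a plain Python method that gets traced as part of fun's graph):

```python
import numpy as np
import tensorflow as tf

class Model:
    def __init__(self):
        self.embedding = tf.keras.layers.Embedding(10, 15)
        self.dense = tf.keras.layers.Dense(10)

    def inference(self, inp):
        # Plain Python method: its ops are inlined into the caller's graph.
        inp_em = self.embedding(inp)
        return self.dense(inp_em)

    @tf.function
    def fun(self, inp):
        return self.inference(inp)

model = Model()
out = model.fun(np.array([1, 2, 3]))  # shape (3, 10)
```

With this layout there is a single graph per input signature, built for fun(), and inference() contributes its ops to that graph rather than getting its own.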

Upvotes: 2

Related Questions