data_person

Reputation: 4490

Does tf.function in TensorFlow optimize run time?

I read that using tf.function can reduce run time by building a graph and evaluating it lazily. Below is sample code that performs a matrix multiplication:

import tensorflow as tf

def random_sample(x1,x2):
    return tf.matmul(x1,x2)

@tf.function
def random_sample_optimized(x1,x2):
    return tf.matmul(x1,x2)

x1 = tf.constant(tf.random.normal(shape=(3999,29999)))
x2 = tf.constant(tf.random.normal(shape=(29999,3999)))
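To see what tf.function actually builds here, one can trace the function once and list the operations in the resulting graph. This is a minimal sketch that assumes TensorFlow 2.x. The `spec` and `op_types` names are illustrative, and the exact op list may differ between versions:

```python
import tensorflow as tf


@tf.function
def random_sample_optimized(x1, x2):
    return tf.matmul(x1, x2)


# Trace once for float32 matrices of unknown size (None = unknown dimension).
# Tracing only records the graph; no multiplication is executed here.
spec = tf.TensorSpec(shape=(None, None), dtype=tf.float32)
concrete = random_sample_optimized.get_concrete_function(spec, spec)

# Collect the type of every operation recorded in the traced graph.
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)
```

For a single matrix multiplication, the traced graph typically holds little besides the input placeholders and the matmul op. That leaves graph mode very little to optimize.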

Calculating run time:

import time
start = time.time()
op1 = random_sample(x1,x2)
end = time.time()
print(end - start)  # avg ~7 s

start = time.time()
op2 = random_sample_optimized(x1,x2)
end = time.time()
print(end - start)  # avg ~9.5 s

Not only was the average higher with tf.function; every individual run was also slower with tf.function.

Am I using tf.function correctly, or does it only provide a speedup when building complex neural networks?
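One caveat applies to the measurement above. The first call to a tf.function traces the Python code into a graph, so a single timed call includes that one-time cost. On a GPU, a call may also return before the computation has finished. The minimal sketch below does three things: it warms up the tf.function first, forces each result with `.numpy()`, and takes the best of several repeats. The smaller matrix shapes and the helper names `run_eager`/`run_graph` are assumptions chosen for a quick run. They do not come from the question:

```python
import timeit

import tensorflow as tf


def random_sample(x1, x2):
    return tf.matmul(x1, x2)


@tf.function
def random_sample_optimized(x1, x2):
    return tf.matmul(x1, x2)


# Smaller stand-in matrices (assumption) so the sketch runs quickly.
x1 = tf.random.normal(shape=(500, 2000))
x2 = tf.random.normal(shape=(2000, 500))

# Warm-up: the first tf.function call pays the one-time tracing cost.
random_sample_optimized(x1, x2)


def run_eager():
    # .numpy() copies the result to the host, waiting for the op to finish.
    return random_sample(x1, x2).numpy()


def run_graph():
    return random_sample_optimized(x1, x2).numpy()


# Best of 3 repeats, each timing 5 consecutive calls.
eager_time = min(timeit.repeat(run_eager, number=5, repeat=3))
graph_time = min(timeit.repeat(run_graph, number=5, repeat=3))
print(f"eager: {eager_time:.3f} s, tf.function: {graph_time:.3f} s (5 calls each)")
```

Both versions end up calling the same matrix-multiplication kernel. Once tracing is excluded, the two timings would therefore be expected to be close for this workload.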

Upvotes: 1

Views: 274

Answers (1)

Abhilash Rajan

Reputation: 349

A tf.function usually runs faster for complex computations, but not necessarily for trivial ones. Here is the relevant passage from the book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow', along with its example:

A TF Function will usually run much faster than the original Python function, especially if it performs complex computations. However, in this trivial example, the computation graph is so small that there is nothing at all to optimize, so tf_cube() actually runs much slower than cube().

import tensorflow as tf

def cube(x):
    return x ** 3

>>> cube(2)
8
>>> cube(tf.constant(2.0))
<tf.Tensor: id=18634148, shape=(), dtype=float32, numpy=8.0>

Now, let’s use tf.function() to convert this Python function to a TensorFlow Function:

>>> tf_cube = tf.function(cube)
>>> tf_cube
<tensorflow.python.eager.def_function.Function at 0x1546fc080>

>>> tf_cube(2)
<tf.Tensor: id=18634201, shape=(), dtype=int32, numpy=8>
>>> tf_cube(tf.constant(2.0))
<tf.Tensor: id=18634211, shape=(), dtype=float32, numpy=8.0>
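The int32 and float32 results above come from separate graphs. tf.function traces a new graph for each new input signature (for tensor arguments, the dtype and shape), and that tracing is part of the cost of a first call. The small sketch below records each trace with a Python side effect. The `traces` list is a hypothetical helper added for illustration:

```python
import tensorflow as tf

traces = []  # hypothetical helper: filled only while tf.function traces


@tf.function
def tf_cube(x):
    # Plain Python statements run during tracing, not on every call,
    # so each append marks one newly traced graph.
    traces.append(x.dtype)
    return x ** 3


tf_cube(tf.constant(2.0))  # float32 scalar: first trace
tf_cube(tf.constant(3.0))  # same dtype and shape: reuses the existing graph
tf_cube(tf.constant(2))    # int32 scalar: new input signature, second trace
print(traces)
```

Here `traces` ends up with two entries, one for float32 and one for int32, because the second call reused the first graph.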

Upvotes: 3
