Reputation: 3079
I have a custom tf.keras.layers.Layer which does a kind of bit unpacking (converting integers to boolean values, 0 or 1 as float32) using only TF operators.
class CharUnpack(keras.layers.Layer):
    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        # Range [7, 6, ..., 0] to bit-shift integers
        self._shifting_range = tf.reshape(
            tf.dtypes.cast(
                tf.range(7, -1, -1, name='shifter_range'),
                tf.uint8,
                name='shifter_cast'),
            (1, 1, 8),
            name='shifter_reshape')
        # Constant value 0b00000001 to use as bitwise and operator
        self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

    def call(self, inputs):
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        try:
            if len(input_shape) > 1:
                output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
            else:
                output_shape = tf.TensorShape((input_shape[0] * 8,))
        except TypeError:
            output_shape = input_shape
        return output_shape

    def compute_output_signature(self, input_signature):
        return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)
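For reference, here is a quick sanity check of the bit order the layer produces (a small example of my own, continuing from the class above; the byte value is arbitrary):

layer = CharUnpack()
packed = tf.constant([[0b10100101]], dtype=tf.uint8)  # one byte, shape (1, 1)
print(layer(packed).numpy())
# -> [[1. 0. 1. 0. 0. 1. 0. 1.]]  (most significant bit first)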
I tried to benchmark this layer to improve its time performance, as shown in this TF guide.
inputs = tf.zeros([64, 400], dtype=tf.uint8)
eager = CharUnpack()

@tf.function
def fun(x):
    eager(x)

# Warm-up
eager(inputs)
fun(inputs)

print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
Function: 0.01062483999885444
Eager: 0.12658399900101358
As you can see, I can get a 10 times speed-up!!!
So I added the @tf.function decorator to my CharUnpack.call method:
+   @tf.function
    def call(self, inputs):
        return tf.dtypes.cast(
Now I expect both the eager and the fun calls to take a similar amount of time, but I see no improvement.
Function: 0.009667591999459546
Eager: 0.10346330100037449
Moreover, section 2.1 of this SO answer states that models are graph-compiled by default (which would make sense), but this does not seem to be the case...
How can I properly use the @tf.function decorator so that my layer is always graph-compiled?
Upvotes: 10
Views: 2180
Reputation: 11333
tl;dr: fun() doesn't return anything. TensorFlow's AutoGraph is smart enough to realize this and ignores all the TF computations happening within fun(), whereas eager(x) has to execute what's happening in the call() function. This is why you're getting a ridiculously low execution time for fun(). At least that's what I think is happening - I'm not an AutoGraph expert, so others might be able to correct me if I've got anything wrong.
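To convince yourself of the pruning behaviour, here is a minimal sketch of mine (not part of the original benchmark) that reproduces the same effect with a plain matmul:

import timeit
import tensorflow as tf

x = tf.random.uniform([2048, 2048])

@tf.function
def no_return(a):
    tf.matmul(a, a)         # result never returned -> the op can be pruned away

@tf.function
def with_return(a):
    return tf.matmul(a, a)  # result is returned -> the matmul must actually run

# Warm-up / tracing
no_return(x)
with_return(x)

print("no_return:  ", timeit.timeit(lambda: no_return(x), number=10))
print("with_return:", timeit.timeit(lambda: with_return(x), number=10))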
Before we dive in, let's simplify things a bit. First, I modified your original code as follows. Let's increase the size of the data to make sure there's enough number crunching involved, so that data transfers and other overheads are not dominating the profiling.
inputs = tf.zeros([8192, 400], dtype=tf.uint8)
Second, I stripped out some computations, e.g. compute_output_shape(), and pinned the shape. I also brought some tensor definitions inside call(), so that call() takes care of everything from variable definitions to computations end-to-end.
class CharUnpack(tf.keras.layers.Layer):
    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        self._shifting_range = None
        self._selection_bit = None

    @tf.function
    def call(self, inputs):
        if not self._shifting_range:
            # Range [7, 6, ..., 0] to bit-shift integers
            self._shifting_range = tf.reshape(
                tf.dtypes.cast(
                    tf.range(7, -1, -1, name='shifter_range'),
                    tf.uint8,
                    name='shifter_cast'
                ),
                (1, 1, 8),
                name='shifter_reshape')
        if not self._selection_bit:
            # Constant value 0b00000001 to use as bitwise and operator
            self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        return [8192, 3200]
Thirdly, I've set number=1 in the timeit call to make sure we're profiling a single call at a time. This makes it easier to understand.
# The very first call of either approach
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))
# The second call
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))
First, let's look at the concrete function of eager():
eager_concrete = eager.call.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))
print(eager_concrete)
which gives,
ConcreteFunction call(inputs)
  Args:
    inputs: uint8 Tensor, shape=(None, 400)
  Returns:
    float32 Tensor, shape=(8192, 3200)
Now let's look at the concrete function of fun():
fun_concrete = fun.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))
print(fun_concrete)
which gives,
ConcreteFunction fun(x)
  Args:
    x: uint8 Tensor, shape=(None, 400)
  Returns:
    NoneTensorSpec()
So you can see straight away that fun() is not returning anything, which should raise red flags. Going a bit further, we can look at what the graph resulting from AutoGraph's tracing actually contains.
graph = fun_concrete.graph
for node in graph.as_graph_def().node:
    print(f'{node.input} -> {node.name}')
which outputs,
[] -> x
['x'] -> CharUnpack/StatefulPartitionedCall
Next, if you do the same for eager(), you will see all the primitive TF operations listed, as below.
[] -> inputs
[] -> StringFormat
['StringFormat'] -> PrintV2
[] -> shifter_range/start
...
['Reshape'] -> Cast
['Cast', '^NoOp'] -> Identity
We can even look at the generated code.
print(tf.autograph.to_code(fun.python_function))
which gives,
def tf__fun(x):
    with ag__.FunctionScope('fun', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        out = ag__.converted_call(ag__.ld(eager), (ag__.ld(x),), None, fscope)
So looking at the code, all it does is generate a converted call of eager with x. I'm not an AutoGraph expert, but I imagine that all it's doing is passing the given input x to eager.call() and discarding the result. So fun() is just skipping all the important computations in the eager.call() function.
How do we make fun() actually do the computations? Simply add a return statement to fun().
@tf.function
def fun(x):
    out = eager(x)
    return out
which gives,
Eager: 0.6245606249999582
Function: 0.3163724480000383
Eager: 0.2076279070001874
Function: 0.22467646699988109
Eager: 0.25076841500003866
Function: 0.240701412999897
So now we can see that both eager.call() and fun() take roughly the same time.
The TF documentation says:
With the exception of tf.Variables, a tf.function must return all its outputs.
Though that section is highlighting a different facet of the problem, it's possibly (indirectly) related to what's going on here.
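As a small illustration of that exception (my own sketch, not from the documentation): a stateful op on a tf.Variable is kept in the graph even when nothing is returned.

import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def accumulate(x):
    # assign_add is a stateful op on a tf.Variable, so it is kept in the
    # graph even though nothing is returned
    v.assign_add(tf.reduce_sum(x))

accumulate(tf.ones([3]))
print(v.numpy())  # 3.0 -- the side effect happened despite the missing return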
Upvotes: 1
Reputation: 902
First, tf.function does not need to be nested; you can just wrap your custom train_step() (which contains the forward and backward propagation). In that case there is no need to wrap an inner layer's or sub-model's call() function, since they are already covered by your train_step. Nested usage may lead to unexpected performance degradation.
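For example, here is a rough sketch of wrapping only the outer training step (the model, optimizer and loss below are placeholders for illustration, not from the question):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y):
    # Decorate only the outer step; every layer's call() runs inside this graph.
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss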
Second, any computational acceleration comes at a cost: tf.function trades space for time and needs an initial call to build the graph. So when benchmarking, we should re-run the same function several times, since subsequent calls of a tf.function do not pay the graph-building cost as long as tracing does not change anything.
for _ in range(5):
    print("Function:", timeit.timeit(lambda: fun(inputs), number=100))

for _ in range(5):
    print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
# Function: 0.02040819999820087
# Function: 0.020311099986429326
# Function: 0.020155799997155555
# Function: 0.02004839999426622
# Function: 0.019999900003313087
# Eager: 0.035980800006655045
# Eager: 0.035652499995194376
# Eager: 0.035596200003055856
# Eager: 0.03490520000923425
# Eager: 0.03762050000659656
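To see when that building (tracing) cost is actually paid again, here is a small sketch of mine; the square function is purely for illustration:

import tensorflow as tf

@tf.function
def square(x):
    print("Tracing with", x)  # Python side effect: only runs while tracing
    return x * x

square(tf.constant(2))    # traces once for int32 tensors
square(tf.constant(3))    # reuses the existing graph, no "Tracing" print
square(tf.constant(2.0))  # new dtype -> new trace, build cost paid again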
Upvotes: 0