Reputation: 48495
I want to build a custom layer using tf.keras. For simplicity, suppose it should return inputs*2 during training and inputs*3 during testing. What is the correct way to do this?
I tried this approach:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs * 2  # training behavior
        else:
            return inputs * 3  # inference behavior
I can then use this class like this:
>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)
It works fine! However, when I use this class in a model and call its fit() method, it seems that training is not set to True. I tried adding the following code at the beginning of the call() method, but training is always equal to 0:
if training is None:
    training = K.learning_phase()
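For reference, here is a minimal sketch of the kind of model where I see this (the extra Dense layer, shapes, loss and optimizer are arbitrary placeholders, just so that fit() has something to train):
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=(1,)),  # dummy trainable layer
    CustomLayer(),
])
model.compile(optimizer="sgd", loss="mse")

x = np.ones((8, 1), dtype="float32")
y = np.ones((8, 1), dtype="float32")

# Inside CustomLayer.call() I would expect training=True here, but it stays 0.
model.fit(x, y, epochs=1, verbose=0)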
What am I missing?
Edit
I found a solution (see my answer), but I'm still looking for a nicer solution using @tf.function (I prefer autograph to this smart_cond() business). Unfortunately, it looks like K.learning_phase() does not play nice with @tf.function. My guess is that when the call() function gets traced, the learning phase gets hard-coded into the graph: since this happens before the call to the fit() method, the learning phase is always 0. This may be a bug, or perhaps there's another way to get the learning phase when using @tf.function.
Upvotes: 1
Views: 1128
Reputation: 48495
François Chollet confirmed that the correct solution when using @tf.function is:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            # fall back to the global Keras learning phase
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3
There's currently a bug (as of Feb 15th, 2019) that makes training always equal to 0, but it will be fixed shortly.
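As a quick sanity check (a minimal sketch; the constant input is arbitrary), the layer behaves as expected when training is passed explicitly:
import tensorflow as tf

layer = CustomLayer()
print(layer(tf.constant(10), training=True))   # inputs * 2 -> tf.Tensor(20, shape=(), dtype=int32)
print(layer(tf.constant(10), training=False))  # inputs * 3 -> tf.Tensor(30, shape=(), dtype=int32)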
Upvotes: 2
Reputation: 48495
The following code does not use @tf.function, so it does not look as nice (since it does not use autograph), but it works fine:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        # smart_cond picks a branch statically when training is a Python bool/int,
        # and builds a tf.cond when it is a tensor
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
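Without @tf.function, the global learning phase can also be toggled by hand to check the behavior (a sketch, assuming a TF version where K.set_learning_phase() is still available):
import tensorflow as tf
from tensorflow.keras import backend as K

layer = CustomLayer()
K.set_learning_phase(1)        # simulate training
print(layer(tf.constant(10)))  # expected: inputs * 2 -> 20
K.set_learning_phase(0)        # back to inference
print(layer(tf.constant(10)))  # expected: inputs * 3 -> 30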
Upvotes: 0