MiniQuark

Reputation: 48495

Using tf.keras in TF 2.0, how can I define a custom layer that depends on the learning phase?

I want to build a custom layer using tf.keras. For simplicity, suppose it should return inputs*2 during training and inputs*3 during testing. What is the correct way to do this?

I tried this approach:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs * 2
        else:
            return inputs * 3

I can then use this class like this:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

It works fine! However, when I use this class in a model and call its fit() method, training does not seem to be set to True. I tried adding the following code at the beginning of the call() method, but training is always equal to 0:

from tensorflow.keras import backend as K

if training is None:
    training = K.learning_phase()
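For reference, here is a minimal, self-contained sketch of the kind of model I am testing with (the Dense layer and the toy data are just placeholders so the model has weights to fit; they are not part of my real setup):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        # *2 during training, *3 during testing
        if training:
            return inputs * 2.0
        return inputs * 3.0

# Toy model: the Dense layer only exists so there is something to train
inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(CustomLayer()(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse")

X = np.ones((4, 1), dtype="float32")
model.fit(X, X, epochs=1, verbose=0)   # should hit the training branch
preds = model.predict(X, verbose=0)    # should hit the testing branch
```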

What am I missing?

Edit

I found a solution (see my answer), but I'm still looking for a nicer solution using @tf.function (I prefer autograph to this smart_cond() business). Unfortunately, K.learning_phase() does not seem to play nice with @tf.function: my guess is that when the call() function gets traced, the learning phase gets hard-coded into the graph. Since tracing happens before fit() is called, the learning phase is always 0. This may be a bug, or perhaps there's another way to get the learning phase when using @tf.function.

Upvotes: 1

Views: 1128

Answers (2)

MiniQuark

Reputation: 48495

François Chollet confirmed that the correct solution when using @tf.function is:

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3

There's currently a bug (as of Feb 15th 2019) that makes training always equal to 0, but this will be fixed shortly.
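In eager mode, with the training argument passed explicitly, the class behaves as expected. Here is a self-contained check (restating the class; the K.learning_phase() fallback is only reached when training is None, so the explicit calls below never touch it):

```python
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        return inputs * 3

layer = CustomLayer()
print(layer(tf.constant(10), training=True))   # expect 20
print(layer(tf.constant(10), training=False))  # expect 30
```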

Upvotes: 2

MiniQuark

Reputation: 48495

The following code does not use @tf.function, so it does not look as nice (since it does not use autograph), but it works fine:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)

Upvotes: 0
