Imago

Reputation: 501

Propagating through a custom layer in tensorflow just once

Given a custom layer in TensorFlow, is it possible to have the model use it during only one epoch? For all other epochs the layer could simply be skipped or act as an identity.

For example, I would like the layer to simply double the given data, while the other layers work normally. How would one do that?

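import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Lambda, Layer
from tensorflow.keras.models import Model


def do_stuff(data):
    # NumPy-side operation: double the input
    return 2 * data


def run_once(data):
    # TF 1.x API (tf.numpy_function / tf.compat.v1.py_func in TF 2.x);
    # op names may not contain spaces
    return tf.py_func(do_stuff,
                      [data],
                      tf.float32,
                      stateful=False,
                      name='run_once')


class CustomLayer(Layer):
    def __init__(self, output_dim, **kwargs):
        # Call the base constructor first, otherwise it resets trainable to True
        super(CustomLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.trainable = False

    def call(self, x):
        res = tf.map_fn(run_once, x)
        # py_func loses the static shape, so restore it
        res.set_shape([x.shape[0],
                       self.output_dim[1],
                       self.output_dim[0],
                       x.shape[-1]])
        return res


inputs = Input(shape=(224, 224, 1))
# preprocess_input is not defined in the question (presumably from a keras.applications module)
x = Lambda(preprocess_input, output_shape=(224, 224, 3))(inputs)
outputs = Dense(1)(x)
model = Model(inputs=inputs, outputs=outputs)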

Upvotes: 1

Views: 259

Answers (1)

rvinas

Reputation: 11895

Interesting question. To execute a TF operation just in the first epoch, one could use tf.cond and tf.control_dependencies to check and then update the value of a boolean variable. For example, your custom layer could be implemented as follows:

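import tensorflow as tf
from tensorflow.keras.layers import Layer


class CustomLayer(Layer):
    def __init__(self, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Non-trainable flag that stays True until the layer has run once
        self.first_epoch = tf.Variable(True, trainable=False)
        super(CustomLayer, self).build(input_shape)

    def call(self, x):
        # Apply run_once only while the flag is still True
        res = tf.cond(self.first_epoch,
                      true_fn=lambda: run_once(x),
                      false_fn=lambda: x)
        # Clear the flag after res is computed, and make the output depend
        # on that update so it is actually executed
        with tf.control_dependencies([res]):
            assign_op = self.first_epoch.assign(False)
            with tf.control_dependencies([assign_op]):
                res = tf.identity(res)
        return res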
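Because the assign runs inside call(), the flag is cleared the first time the layer's forward pass executes, which during Keras training is the first batch rather than the whole first epoch. The plain-Python sketch below (no TensorFlow; `RunOnceLayer` and the toy batches are hypothetical stand-ins) mirrors the cond/assign logic so the effect can be seen over two simulated epochs of two batches each:

```python
# Framework-free model of the cond/assign logic in CustomLayer above.
class RunOnceLayer:
    def __init__(self, fn):
        self.fn = fn             # operation applied while the flag is set
        self.first_epoch = True  # plays the role of tf.Variable(True)

    def __call__(self, x):
        res = self.fn(x) if self.first_epoch else x  # like tf.cond(...)
        self.first_epoch = False                     # like assign(False) after res
        return res


layer = RunOnceLayer(lambda batch: [2 * v for v in batch])
outputs = []
for epoch in range(2):
    for batch in ([1.0, 2.0], [3.0, 4.0]):
        outputs.append(layer(batch))

print(outputs)
```

Only the very first batch comes back doubled; every later batch, including the rest of epoch 0, passes through unchanged.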

To validate that this layer works as expected, define run_once as:

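import tensorflow as tf


def run_once(data):
    # Print a message whenever this branch runs, then return the input unchanged
    print_op = tf.print('First epoch')
    with tf.control_dependencies([print_op]):
        out = tf.identity(data)
    return out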
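If the doubling should instead cover the entire first epoch, one option is to remove the assign from call() and clear the flag from an epoch-end hook. In Keras that would presumably be a `tf.keras.callbacks.Callback` whose `on_epoch_end` sets the layer's variable to False (for example with `Variable.assign` or `keras.backend.set_value`); this is an untested assumption, not part of the original answer. The framework-free sketch below models that variant with hypothetical `EpochGatedLayer` and `ClearFlagAtEpochEnd` classes:

```python
# Variant: the forward pass only checks the flag; an epoch-end hook clears it.
class EpochGatedLayer:
    def __init__(self, fn):
        self.fn = fn
        self.first_epoch = True  # would be a non-trainable tf.Variable

    def __call__(self, x):
        # No per-call update here, only the check (like tf.cond)
        return self.fn(x) if self.first_epoch else x


class ClearFlagAtEpochEnd:
    # Stand-in for a Keras Callback that clears the flag in on_epoch_end
    def __init__(self, layer):
        self.layer = layer

    def on_epoch_end(self, epoch, logs=None):
        self.layer.first_epoch = False


layer = EpochGatedLayer(lambda batch: [2 * v for v in batch])
callback = ClearFlagAtEpochEnd(layer)
outputs = []
for epoch in range(2):
    for batch in ([1.0, 2.0], [3.0, 4.0]):
        outputs.append(layer(batch))
    callback.on_epoch_end(epoch)

print(outputs)
```

Here both batches of the first simulated epoch are doubled and the second epoch passes through unchanged.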

Upvotes: 1
