mohhao

Reputation: 35

How to wrap tf.cond function with keras.layers.Lambda?

I'm trying to define a custom layer in Keras, but I can't find a way to wrap tf.cond with a layers.Lambda function:

matches = tf.cond(
    tf.greater(N, 0),
    lambda: match_boxes(
        anchors, groundtruth_boxes,
        positives_threshold=positives_threshold,
        negatives_threshold=negatives_threshold,
        force_match_groundtruth=True
    ),
    lambda: only_background
)

Upvotes: 2

Views: 1277

Answers (1)

Vlad

Reputation: 8595

Since the body of your true function is very big, you could create a custom layer like this:

import tensorflow as tf

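class CustomLayer(tf.keras.layers.Layer):

  def __init__(self, pred=False, **kwargs):
    # Forward the remaining keyword arguments (e.g. name) to the base Layer.
    super(CustomLayer, self).__init__(**kwargs)
    # Scalar boolean predicate that selects the branch at run time.
    self.pred = pred

  def call(self, inputs):
    def true_fn(x):
      # Stand-in for the real, larger true-branch body.
      return x + 1.

    return tf.cond(self.pred,
                   true_fn=lambda: true_fn(inputs),
                   false_fn=lambda: tf.identity(inputs))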

Testing:

# TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session).
inputs = tf.placeholder(tf.float32, shape=(None, 1))
pred = tf.placeholder(tf.bool, shape=())

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(1, kernel_initializer=tf.initializers.ones))
model.add(CustomLayer(pred=pred))

outputs = model(inputs)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(outputs.eval({inputs: [[1.]], pred: False})) # [[1.]]
  print(outputs.eval({inputs: [[1.]], pred: True})) # [[2.]]
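
If the branch bodies are small, wrapping tf.cond directly in a Lambda layer may be enough. Here is a minimal sketch, assuming TensorFlow 2.x with eager execution. The function add_one_if_nonempty and its x + 1 branch are hypothetical stand-ins for match_boxes and only_background from the question, and the predicate is computed from the input itself, similar to tf.greater(N, 0):

```python
import tensorflow as tf

def add_one_if_nonempty(x):
    # Hypothetical stand-in for the question's branches:
    # the true branch plays the role of match_boxes(...),
    # the false branch plays the role of only_background.
    n = tf.shape(x)[0]  # number of rows, analogous to N in the question
    return tf.cond(tf.greater(n, 0),
                   true_fn=lambda: x + 1.0,
                   false_fn=lambda: tf.identity(x))

# Wrap the function that contains tf.cond in a Lambda layer.
cond_layer = tf.keras.layers.Lambda(add_one_if_nonempty)

# Non-empty input: the true branch runs.
result = cond_layer(tf.constant([[1.0], [2.0]]))

# Empty input (zero rows): the false branch returns the input unchanged.
empty_result = cond_layer(tf.zeros((0, 1)))
```

Both branches have to return tensors with the same structure and dtype, which is why the false branch returns the input through tf.identity rather than a value of a different type.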

Upvotes: 1
