Christopher Mills

Reputation: 760

tf.cond() returns a tensor of unknown shape

Below is the function that I am passing to a Keras Lambda layer.

I am having a problem with the output of tf.cond(): it returns a tensor of shape <unknown>. The input tensor (t) and the constant weight tensors have shapes (None, 6) and (6,), respectively. When I add these two outside tf.cond(), I get a tensor of shape (None, 6), which is what I need. However, when the same add operation is returned from within tf.cond(), I get a tensor of shape <unknown>.

What changes when this operation goes through tf.cond()?

import tensorflow as tf

def class_segmentation(t):
    class_segments = tf.constant([0, 0, 1, 1, 2, 2])

    a = tf.math.segment_mean(t, class_segments, name=None)

    b = tf.math.argmax(a)

    left_weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    middle_weights = tf.constant([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
    right_weights = tf.constant([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
    zero_weights = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

    c = tf.cond(tf.math.equal(b, 0), lambda: tf.math.add(t, left_weights), lambda: zero_weights)
    d = tf.cond(tf.math.equal(b, 1), lambda: tf.math.add(t, middle_weights), lambda: zero_weights)
    e = tf.cond(tf.math.equal(b, 2), lambda: tf.math.add(t, right_weights), lambda: zero_weights)

    f = tf.math.add_n([c, d, e])
    print("Tensor shape: ", f.shape)  # prints <unknown>

    return f
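The two shapes described in the question can be reproduced side by side. This is a minimal sketch assuming TensorFlow 2.x, with `tf.function` tracing standing in for the Keras Lambda layer; `compare_shapes` and `traced_shapes` are hypothetical names. The static shapes are recorded during tracing, while the batch dimension is still `None`.

```python
import tensorflow as tf

left_weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
zero_weights = tf.zeros([6])
traced_shapes = {}  # filled in while the function is traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 6], dtype=tf.float32),
                              tf.TensorSpec(shape=[], dtype=tf.bool)])
def compare_shapes(t, pred):
    # Broadcasting outside tf.cond(): (None, 6) + (6,) -> (None, 6)
    direct = tf.math.add(t, left_weights)
    # Inside tf.cond() the branches disagree: (None, 6) versus (6,)
    via_cond = tf.cond(pred,
                       lambda: tf.math.add(t, left_weights),
                       lambda: zero_weights)
    traced_shapes["direct"] = direct.shape
    traced_shapes["via_cond"] = via_cond.shape
    return direct, via_cond

compare_shapes.get_concrete_function()  # trace once with the signature above
print(traced_shapes["direct"])    # (None, 6)
print(traced_shapes["via_cond"])  # <unknown>
```

In a traced graph, tf.cond() can only promise a shape that is compatible with both branches; since the branches have different ranks, the only safe answer is an unknown shape. In eager mode tf.cond() just runs one branch, so the mismatch does not show up there.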

Upvotes: 1

Views: 522

Answers (1)

Vlad

Reputation: 8595

You have a few problems in your code.

  1. tf.math.segment_mean() expects the segment IDs (your class_segments) to have the same size as the first dimension of your input t. So the batch dimension, None, must equal 6 for your code to run. This is most likely why you get the unknown shape: the shapes of your tensors depend on None, which is only determined at runtime. You could transpose t so that the segments run along the feature axis (not sure if that is what you are trying to achieve), e.g.
     a = tf.math.segment_mean(tf.transpose(t), class_segments)
  2. In tf.cond(), true_fn and false_fn must return tensors of the same shape. In your case true_fn returns a tensor of shape (None, 6) because of broadcasting, while false_fn returns a tensor of shape (6,).
  3. The predicate in tf.cond() must be a scalar (rank 0). For example, if you were to compute b = tf.math.argmax(tf.math.segment_mean(tf.transpose(t), class_segments), 0), then b would have shape (None,), so the predicate tf.math.equal(b, 0) passed to tf.cond() would also be broadcast to shape (None,) (which will raise an error).
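Items 1 and 3 together show what the transposed version computes: one argmax per sample rather than a single scalar. The sketch below assumes TensorFlow 2.x in eager mode, a made-up batch of two rows, and that the goal is to add the weight vector of each sample's winning segment. It replaces the three tf.cond() calls with `tf.gather()`, which is a swapped-in technique, not something the answer proposes; `weight_table` is a hypothetical name.

```python
import tensorflow as tf

class_segments = tf.constant([0, 0, 1, 1, 2, 2])
left_weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
middle_weights = tf.constant([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
right_weights = tf.constant([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

# Made-up batch of 2 samples with 6 features each, shape (2, 6).
t = tf.constant([[5.0, 3.0, 1.0, 1.0, 0.0, 2.0],
                 [0.0, 1.0, 4.0, 2.0, 6.0, 6.0]])

# Transpose so the 6 features are on axis 0, matching the 6 segment ids.
a = tf.math.segment_mean(tf.transpose(t), class_segments)  # shape (3, 2)
b = tf.math.argmax(a, 0)  # shape (2,): one winning segment per sample

# Row i of weight_table is the weight vector for segment i.
weight_table = tf.stack([left_weights, middle_weights, right_weights])  # (3, 6)
f = t + tf.gather(weight_table, b)  # (2, 6): each row gets its own weights

print(a.shape, b.shape, f.shape)  # (3, 2) (2,) (2, 6)
print(b.numpy())  # [0 2]
```

Because no branch is selected, nothing has to agree across branches: with an unknown batch size, `tf.gather(weight_table, b)` has shape (None, 6), and so does `f`.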
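If a single branch for the whole batch is what you want, items 2 and 3 translate into two changes: reduce the predicate to a scalar, and make the false branch return the same shape as the true branch. The sketch below assumes TensorFlow 2.x. Averaging the segment scores over the batch before the argmax is an assumption made only to get a rank-0 predicate, `tf.zeros_like(t)` stands in for the (6,)-shaped zero_weights, and `left_branch_only` and `traced_shapes` are hypothetical names.

```python
import tensorflow as tf

class_segments = tf.constant([0, 0, 1, 1, 2, 2])
left_weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
traced_shapes = {}  # filled in while the function is traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 6], dtype=tf.float32)])
def left_branch_only(t):
    a = tf.math.segment_mean(tf.transpose(t), class_segments)  # (3, batch)
    # Assumption: average over the batch so the argmax is a scalar (rank 0).
    b = tf.math.argmax(tf.reduce_mean(a, axis=1))
    c = tf.cond(
        tf.math.equal(b, 0),
        lambda: tf.math.add(t, left_weights),  # (None, 6) via broadcasting
        lambda: tf.zeros_like(t),              # (None, 6) as well
    )
    traced_shapes["pred"] = b.shape
    traced_shapes["c"] = c.shape
    return c

left_branch_only.get_concrete_function()  # trace once with the signature above
print(traced_shapes["pred"], traced_shapes["c"])  # () (None, 6)
```

With both branches returning (None, 6), the traced output keeps that static shape instead of <unknown>, and calling `left_branch_only(tf.ones([4, 6]))` returns a (4, 6) tensor.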

Upvotes: 1
