Below is the function that I am passing to a Keras `Lambda` layer.

I am having a problem with the output of `tf.cond()`: it returns a shape of `<unknown>`. The input tensor (`t`) and the constant weight tensor have shapes of `(None, 6)` and `(6,)`, respectively. When I add these two outside of `tf.cond()`, I get a tensor of shape `(None, 6)`, which is what I need. However, when the same add operation is returned from within `tf.cond()`, I get a tensor of shape `<unknown>`.

What changes when this operation goes through `tf.cond()`?
```python
import tensorflow as tf

def class_segmentation(t):
    class_segments = tf.constant([0, 0, 1, 1, 2, 2])
    a = tf.math.segment_mean(t, class_segments, name=None)
    b = tf.math.argmax(a)
    left_weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    middle_weights = tf.constant([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
    right_weights = tf.constant([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
    zero_weights = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    c = tf.cond(tf.math.equal(b, 0), lambda: tf.math.add(t, left_weights), lambda: zero_weights)
    d = tf.cond(tf.math.equal(b, 1), lambda: tf.math.add(t, middle_weights), lambda: zero_weights)
    e = tf.cond(tf.math.equal(b, 2), lambda: tf.math.add(t, right_weights), lambda: zero_weights)
    f = tf.math.add_n([c, d, e])
    print("Tensor shape: ", f.shape)  # prints "<unknown>"
    return f
```
You have a few problems in your code.

`tf.math.segment_mean()` expects `class_segments` to have the same size as the first dimension of your input `t`. So `None` must be equal to `6` in order for your code to run. This is most likely the cause of the unknown shape, because the shapes of your tensors depend on `None`, which is only determined at runtime. You could apply a transformation so that your code runs (I'm not sure if that is what you are trying to achieve), e.g. `a = tf.math.segment_mean(tf.transpose(t), class_segments)`.
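A minimal sketch of the `segment_mean()` size requirement, assuming a hypothetical batch of 4 examples with 6 features each (the batch size and the variable `mismatch_failed` are illustrative, not from the question). Segmenting `t` directly fails because `t` has 4 rows but `class_segments` has 6 entries; transposing first assigns each of the 6 feature columns to a class.

```python
import tensorflow as tf

# Assumption for illustration: a batch of 4 examples with 6 features each.
t = tf.reshape(tf.range(24, dtype=tf.float32), [4, 6])
class_segments = tf.constant([0, 0, 1, 1, 2, 2])

# segment_mean segments along axis 0, so segment_ids needs one entry per row.
# t has 4 rows but class_segments has 6 entries, so this call fails at runtime.
try:
    tf.math.segment_mean(t, class_segments)
    mismatch_failed = False
except tf.errors.InvalidArgumentError:
    mismatch_failed = True

# Transposing puts the 6 features on axis 0, so each feature column gets a class id.
a = tf.math.segment_mean(tf.transpose(t), class_segments)

print(mismatch_failed)        # True
print(a.shape)                # (3, 4): one row per class, one column per example
print(a[0].numpy().tolist())  # [0.5, 6.5, 12.5, 18.5]
```

Note that the result is now laid out as classes by examples, so any later `argmax` over the classes has to use axis 0.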
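The `true_fn` and `false_fn` of `tf.cond()` should return tensors of the same shape. In your case `true_fn` returns `(None, 6)` because of broadcasting, while `false_fn` returns a tensor of shape `(6,)`. When the function is traced symbolically (as happens inside a Keras `Lambda` layer), TensorFlow gives the `tf.cond()` output the most specific shape compatible with both branches; since `(None, 6)` and `(6,)` have different ranks, that is a fully unknown shape, which is the `<unknown>` you are seeing.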
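The branch-shape issue can be reproduced with a small sketch, assuming TF 2.x and using `tf.function` with an `input_signature` to trace `tf.cond()` symbolically, roughly the way a `Lambda` layer is traced. The function names `mismatched` and `matched` are hypothetical, and returning `tf.zeros_like(t)` in the false branch is one possible way (not the only one) to make both branches `(None, 6)`.

```python
import tensorflow as tf

weights = tf.constant([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])  # shape (6,)
zero_weights = tf.zeros([6])                          # shape (6,)

# Symbolic (None, 6) input plus a scalar boolean predicate, so tf.cond is
# traced into a graph instead of being executed eagerly.
signature = [tf.TensorSpec([None, 6], tf.float32), tf.TensorSpec([], tf.bool)]

@tf.function(input_signature=signature)
def mismatched(t, pred):
    # true branch: (None, 6) after broadcasting; false branch: (6,)
    return tf.cond(pred, lambda: t + weights, lambda: zero_weights)

@tf.function(input_signature=signature)
def matched(t, pred):
    # false branch now has the same shape as t, i.e. (None, 6)
    return tf.cond(pred, lambda: t + weights, lambda: tf.zeros_like(t))

mismatched_shape = mismatched.get_concrete_function().structured_outputs.shape
matched_shape = matched.get_concrete_function().structured_outputs.shape
print(mismatched_shape)  # <unknown>
print(matched_shape)     # (None, 6)
```

With matching branches the static shape survives, and calling `matched(tf.ones([2, 6]), tf.constant(False))` returns a `(2, 6)` tensor of zeros.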
The predicate passed to `tf.cond()` must be reduced to rank 0 (a scalar boolean). For example, if you were to apply

`b = tf.math.argmax(tf.math.segment_mean(tf.transpose(t), class_segments), 0)`

then the shape of `b` would be `(None,)`, and `tf.math.equal(b, 0)` would broadcast `0` to that same shape, so the predicate `pred` passed to `tf.cond()` would have shape `(None,)` instead of being a scalar (which will raise an error).
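A short sketch that confirms the predicate's static shape, assuming TF 2.x and tracing with a `(None, 6)` input signature; the function name `predicate_shapes` is hypothetical. It only inspects shapes and does not call `tf.cond()`.

```python
import tensorflow as tf

class_segments = tf.constant([0, 0, 1, 1, 2, 2])

@tf.function(input_signature=[tf.TensorSpec([None, 6], tf.float32)])
def predicate_shapes(t):
    # Class means per example (3 x batch at runtime); argmax over axis 0
    # leaves one class index per example, so b has static shape (None,).
    b = tf.math.argmax(tf.math.segment_mean(tf.transpose(t), class_segments), 0)
    # Comparing with the scalar 0 broadcasts it, so pred is (None,), not rank 0.
    pred = tf.math.equal(b, 0)
    return b, pred

b, pred = predicate_shapes.get_concrete_function().structured_outputs
print(b.shape, pred.shape)  # (None,) (None,)
```

Calling it on a concrete batch such as `tf.ones([4, 6])` gives a `b` of shape `(4,)`: one class decision per example, which a single scalar `tf.cond()` cannot express.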