Reputation: 8352
I have an index vector that may contain negative entries. How can I use it in tf.gather? My approach
import tensorflow as tf

params = tf.constant(range(5))
idx = tf.constant([-1, 1, 2])
tf.where(
    condition=idx >= 0,
    x=tf.gather(params, idx),  # evaluated on every index, including -1
    y=-1,
)
throws
InvalidArgumentError: indices[0] = -1 is not in [0, 5) [Op:GatherV2]
because the x branch, tf.gather(params, idx), is evaluated for all elements before tf.where selects between x and y. I do not want to remove the invalid indices because I need to retain the positional information, i.e. the desired output is [-1, 1, 2] (rather than [1, 2], which I would get by discarding the invalid indices).
Upvotes: 0
Views: 556
Reputation: 1651
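Replace the negative indices with a valid placeholder (here 0) before calling tf.gather, so the gather itself never sees an out-of-range index, then use an outer tf.where to overwrite those positions with -1: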
tf.where(idx >= 0, tf.gather(params, tf.where(idx >= 0, idx, 0)), -1)
Output
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([-1, 1, 2])>
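If you need this in more than one place, the same idea can be wrapped in a small helper. This is only a sketch: the name gather_with_fill and the fill_value parameter are made up for illustration, and it assumes TensorFlow 2.x, a 1-D params, and that only negative indices are invalid (indices beyond the end of params would still raise).

```python
import tensorflow as tf


def gather_with_fill(params, idx, fill_value=-1):
    """Gather params[idx], writing fill_value wherever idx is negative.

    Hypothetical helper; assumes params is 1-D and idx has an integer dtype.
    """
    valid = idx >= 0
    # Swap negative indices for 0 so tf.gather only sees in-range indices.
    safe_idx = tf.where(valid, idx, tf.zeros_like(idx))
    gathered = tf.gather(params, safe_idx)
    # Match the fill value's dtype to the gathered values before selecting.
    fill = tf.cast(fill_value, gathered.dtype)
    return tf.where(valid, gathered, fill)


params = tf.constant(range(5))
idx = tf.constant([-1, 1, 2])
print(gather_with_fill(params, idx).numpy())  # [-1  1  2]
```

The placeholder index 0 is arbitrary; any in-range index works, since the values gathered at the masked positions are discarded by the final tf.where.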
Upvotes: 1