fuenfundachtzig

Reputation: 8352

How to use tf.gather with index vector that may contain out-of-range indices?

I have an index vector that may contain negative entries. How can I use it with tf.gather? My approach:

import tensorflow as tf

params = tf.constant(range(5))
idx = tf.constant([-1, 1, 2])
tf.where(
    condition=idx >= 0,
    x=tf.gather(params, idx),  # gather runs on every index, including -1
    y=-1
)

throws

InvalidArgumentError: indices[0] = -1 is not in [0, 5) [Op:GatherV2]

because the x branch is evaluated for all elements, including the invalid ones. I do not want to remove the invalid indices because I need to keep the positional information: the desired output is [-1, 1, 2], not [1, 2] as I would get by discarding the invalid indices.

Upvotes: 0

Views: 556

Answers (1)

bui

Reputation: 1651

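You can do it as follows. The inner tf.where replaces each negative index with 0, a valid placeholder, so tf.gather never sees an out-of-range index; the outer tf.where then overwrites the values gathered at those positions with -1: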

tf.where(idx >= 0, tf.gather(params, tf.where(idx >= 0, idx, 0)), -1)

Output

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([-1,  1,  2])>
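The one-liner above only guards against negative indices, while the question title also mentions out-of-range indices in general. A minimal sketch of the same clamp-and-mask idea that also catches indices that are too large might look like this. `gather_with_fill` is a hypothetical helper name, and the sketch assumes `params` is a non-empty 1-D tensor, so that 0 is always a safe placeholder index.

```python
import tensorflow as tf


def gather_with_fill(params, indices, fill_value):
    """Gather from a 1-D tensor, writing fill_value wherever an index is out of range.

    Hypothetical helper for illustration; assumes params is non-empty and 1-D.
    """
    params = tf.convert_to_tensor(params)
    indices = tf.convert_to_tensor(indices)
    n = tf.shape(params, out_type=indices.dtype)[0]
    # An index is valid only if 0 <= index < len(params).
    valid = tf.logical_and(indices >= 0, indices < n)
    # Swap invalid indices for 0 so tf.gather never receives an out-of-range index.
    safe_indices = tf.where(valid, indices, tf.zeros_like(indices))
    gathered = tf.gather(params, safe_indices)
    # Overwrite the placeholder results with the fill value.
    fill = tf.fill(tf.shape(gathered), tf.cast(fill_value, gathered.dtype))
    return tf.where(valid, gathered, fill)


params = tf.constant(list(range(5)))
idx = tf.constant([-1, 1, 2, 7])
print(gather_with_fill(params, idx, -1).numpy())  # [-1  1  2 -1]
```

Computing the mask once and reusing it for both tf.where calls keeps the index check and the final masking consistent with each other.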

Upvotes: 1
