Anton

Reputation: 390

TensorFlow: Accessing indices of a Variable via a tensor value?

I am trying to reproduce the following NumPy behavior in TensorFlow:

import numpy as np
z = np.zeros(2 * 10 - 1, dtype=np.float32)
z[[2, 10]] = 1  # set the elements at positions 2 and 10 to 1

What I have so far:

test = tf.Variable(tf.zeros(2 * 10 - 1, dtype=tf.float32))
test[tf.constant([2,10])].assign(1)

I need test to be a Variable, so I cannot simply use a constant tensor of zeros.

Running this raises the following error:

InvalidArgumentError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [19], [1,2], [1,2], [1].

This doesn't make sense to me: the index I am providing has rank 1, yet it appears to get reshaped to [1,2] somewhere.

How can I reproduce the NumPy behavior above?

Upvotes: 0

Views: 56

Answers (1)

Anton

Reputation: 390

This seems the closest to what I was looking for, but I am worried that it creates a copy of the variable, which would be expensive if test is large.

tf.scatter_add(test, [2, 10], 1)
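
A possible in-place alternative, sketched under the assumption of TensorFlow 2.x with eager execution (the question does not state a version): indexing a Variable with test[...] goes through strided slicing, which does not accept a list of indices the way NumPy does, but the Variable's scatter_nd_update method writes the given values at the given indices and modifies the variable itself. Unlike scatter_add, it assigns rather than adds, so it also matches the NumPy snippet when the targeted entries are not zero. The names indices, updates and result below are only for illustration.

```python
import tensorflow as tf

# Same starting point as in the question: a variable holding 19 zeros.
test = tf.Variable(tf.zeros(2 * 10 - 1, dtype=tf.float32))

# Each row of `indices` addresses one element of the rank-1 variable.
indices = tf.constant([[2], [10]])
updates = tf.ones(2, dtype=tf.float32)

# Assumption: TF 2.x eager mode, where this call updates `test` itself.
test.scatter_nd_update(indices, updates)

# Read the variable back as a NumPy array for inspection.
result = test.numpy()
```

For plain tensors rather than variables, tf.tensor_scatter_nd_update(tensor, indices, updates) is the counterpart; it returns a new tensor instead of modifying anything in place.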

Better answers are welcome.

Upvotes: 1
