shoseta

Reputation: 45

Fill a specific index in tensor with a value

I'm a beginner with TensorFlow. I created this tensor:

z = tf.zeros([20,2], tf.float32)

and I want to change the values at z[2,1] and z[2,2] from 0 to 1.0. How can I do that?

Upvotes: 4

Views: 8843

Answers (4)

AloneTogether

Reputation: 26708

A TensorFlow 2.x solution to this problem would look like this:

import tensorflow as tf

z = tf.zeros([20,2], dtype=tf.float32)

index1 = [2, 0]
index2 = [2, 1]

result = tf.tensor_scatter_nd_update(z, [index1, index2], [1.0, 1.0])
tf.print(result)

Output:

[[0 0]
 [0 0]
 [1 1]
 ...
 [0 0]
 [0 0]
 [0 0]]

Upvotes: 0

Abhijit Balaji

Reputation: 1940

A better way to accomplish this is to use tf.sparse_to_dense, which builds the dense tensor directly from a list of indices and the values to place there.

tf.sparse_to_dense(sparse_indices=[[0, 0], [1, 2]],
                   output_shape=[3, 4],
                   default_value=0,
                   sparse_values=1,
                   )

Output:

[[1, 0, 0, 0]
 [0, 0, 1, 0]
 [0, 0, 0, 0]]

However, tf.sparse_to_dense is deprecated and is no longer available as a top-level function in TensorFlow 2.x. Instead, build a tf.SparseTensor and convert it with tf.sparse.to_dense to get the same result as above.
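For reference, the tf.SparseTensor route mentioned above could look roughly like the sketch below. It assumes TensorFlow 2.x with eager execution; the variable names sparse and dense are only illustrative, and the indices and shape are the same ones used in the tf.sparse_to_dense call.

```python
import tensorflow as tf

# Describe the non-zero entries: one [row, col] pair per value.
# Indices must be listed in row-major (lexicographic) order.
sparse = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                         values=[1, 1],
                         dense_shape=[3, 4])

# Expand to a regular dense tensor; unspecified entries take default_value.
dense = tf.sparse.to_dense(sparse, default_value=0)
print(dense.numpy())
```

This yields a 3x4 tensor with ones at [0, 0] and [1, 2] and zeros elsewhere.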

Upvotes: 2

a.ewais

Reputation: 43

An easy way is to build the values in NumPy first and then wrap them in a variable:

import numpy as np
import tensorflow as tf

init = np.zeros((20,2), np.float32)
init[2,1] = 1.0
z = tf.Variable(init)

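Alternatively, use tf.scatter_update(ref, indices, updates), a TensorFlow 1.x op that overwrites whole slices of a variable along its first dimension: https://www.tensorflow.org/api_docs/python/tf/scatter_update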
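In TensorFlow 2.x, a similar in-place update of individual elements can be sketched with the scatter_nd_update method of tf.Variable. This is a swapped-in alternative to tf.scatter_update, not the same op, and it assumes eager execution; the indices used here are [2, 0] and [2, 1].

```python
import tensorflow as tf

# A variable can be modified in place, unlike the result of tf.zeros.
z = tf.Variable(tf.zeros([20, 2], tf.float32))

# Each index is a full [row, col] coordinate; each update is one scalar.
z.scatter_nd_update(indices=[[2, 0], [2, 1]], updates=[1.0, 1.0])

print(z.numpy()[2])  # the row that was updated
```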

Upvotes: 1

Maxim

Reputation: 53778

Exactly what you ask for is not possible, for two reasons:

  • z is a constant tensor, so it can't be changed.
  • There is no z[2,2], only z[2,0] and z[2,1].

But assuming you want to change z to a variable and fix the indices, it can be done this way (TensorFlow 1.x API):

z = tf.Variable(tf.zeros([20,2], tf.float32))  # a variable, not a const
assign21 = tf.assign(z[2, 0], 1.0)             # an op to update z
assign22 = tf.assign(z[2, 1], 1.0)             # an op to update z

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(z))                           # prints all zeros
  sess.run([assign21, assign22])
  print(sess.run(z))                           # prints 1.0 in the 3rd row
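The same idea can be sketched in TensorFlow 2.x, where there is no session and a sliced variable can be assigned to directly. This assumes eager execution (the TF 2.x default) and is only an illustration of the approach above, not a drop-in replacement for the session-based code.

```python
import tensorflow as tf

# A variable, not a constant, so its contents can be updated in place.
z = tf.Variable(tf.zeros([20, 2], tf.float32))

# Sliced assignment: each call writes a single element of the variable.
z[2, 0].assign(1.0)
z[2, 1].assign(1.0)

print(z.numpy()[2])  # the third row, after the update
```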

Upvotes: 6
