mon

Reputation: 22284

tensorflow 2 - how to directly update elements in tf.Variable X at indices?

Is there a way to directly update the elements in tf.Variable X at indices without creating a new tensor having the same shape as X?

tf.tensor_scatter_nd_update creates a new tensor, so it does not appear to update the original tf.Variable. From its documentation:

This operation creates a new tensor by applying sparse updates to the input tensor.

tf.Variable.assign apparently needs a new tensor value with the same shape as X in order to update the tf.Variable X.

assign(
    value, use_locking=False, name=None, read_value=True
)

value: A Tensor. The new value for this variable.

Upvotes: 1

Views: 1654

Answers (1)

Innat

Reputation: 17239

You're right that tf.tensor_scatter_nd_update returns a new tf.Tensor (and not a tf.Variable). As for assign, which is a method of tf.Variable, I think you somewhat misread the documentation: when assign is called on a slice of the variable, value is just the new item you want to write at those indices of your existing variable, so it only has to match the shape of that slice.

AFAIK, in TensorFlow all tensors are immutable, like numbers and strings: you can never update the contents of a tensor, only create a new one (source). Direct item assignment on a tf.Tensor or tf.Variable, i.e. x[i] = value, is still not supported. See GitHub issues #33131 and #14132 to follow the discussion.
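
To make the limitation concrete, here is a minimal sketch (assuming TensorFlow 2.x in eager mode) that tries NumPy-style item assignment on both a tensor and a variable. The flag names are only illustrative.

```python
import tensorflow as tf

t = tf.constant([1, 2, 3])
try:
    t[1] = 0  # NumPy-style item assignment on an eager tensor
    tensor_assignment_failed = False
except TypeError:
    tensor_assignment_failed = True

v = tf.Variable([1, 2, 3])
try:
    v[1] = 0  # the same syntax on a variable
    variable_assignment_failed = False
except TypeError:
    variable_assignment_failed = True

print(tensor_assignment_failed, variable_assignment_failed)
```

Both attempts should be rejected, because neither type implements item assignment through the [] = syntax.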


In NumPy, we can do the in-place item assignment that you showed in the comment box.

import numpy as np

a = np.array([1,2,3])  
print(a) # [1 2 3]
a[1] = 0
print(a) # [1 0 3] 

A similar result can be achieved with a tf.Variable by calling the assign method on the indexed element.

import tensorflow as tf 

b = tf.Variable([1,2,3])
b.numpy() # array([1, 2, 3], dtype=int32)

b[1].assign(0)
b.numpy() # array([1, 0, 3], dtype=int32)
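
The same pattern should extend to a range of indices. The following is a small sketch, assuming TensorFlow 2.x in eager mode; the name same_object is only there to illustrate that the existing variable is modified rather than replaced.

```python
import tensorflow as tf

b = tf.Variable([1, 2, 3, 4])
same_object = b          # second reference to the same variable

# The value only needs to match the shape of the slice (2 elements here),
# not the shape of the whole variable.
b[1:3].assign([0, 0])

print(b.numpy())            # [1 0 0 4]
print(same_object.numpy())  # [1 0 0 4], the original variable was updated
```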

Later, we can convert it to a tf.Tensor as follows.

b_ten = tf.convert_to_tensor(b)
b_ten.numpy() # array([1, 0, 3], dtype=int32)

We can do this kind of item assignment on a tf.Tensor too, but we need to convert it to a tf.Variable first (I know, not very intuitive).

tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
x = tf.tensor_scatter_nd_update(tensor, indices, updates)
x
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 1,  5],
       [ 1,  1],
       [10,  1]], dtype=int32)>
x = tf.Variable(x)
x
<tf.Variable 'Variable:0' shape=(3, 2) dtype=int32, numpy=
array([[ 1,  5],
       [ 1,  1],
       [10,  1]], dtype=int32)>

x[0].assign([5,1])
x
<tf.Variable 'Variable:0' shape=(3, 2) dtype=int32, numpy=
array([[ 5,  1],
       [ 1,  1],
       [10,  1]], dtype=int32)>
x = tf.convert_to_tensor(x)
x
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 5,  1],
       [ 1,  1],
       [10,  1]], dtype=int32)>
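
For scattered (non-contiguous) indices, as in the original question, one option worth trying is the scatter_nd_update method of tf.Variable, which, as far as I know, writes into the existing variable instead of returning a new tensor. This is a sketch assuming TensorFlow 2.x in eager mode, reusing the indices and updates from the example above.

```python
import tensorflow as tf

x = tf.Variable([[1, 1], [1, 1], [1, 1]])
indices = tf.constant([[0, 1], [2, 0]])  # positions (row, col) to overwrite
updates = tf.constant([5, 10])           # one value per position

# Updates the variable itself; no tensor of x's full shape is passed in.
x.scatter_nd_update(indices, updates)

print(x.numpy())
# [[ 1  5]
#  [ 1  1]
#  [10  1]]
```

Unlike tf.tensor_scatter_nd_update, this avoids the round trip through a new tf.Tensor and back into a tf.Variable.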

Upvotes: 2
