Reputation: 22284
Is there a way to directly update the elements of a tf.Variable
X at given indices without creating a new tensor with the same shape as X?
tf.tensor_scatter_nd_update creates a new tensor, so it does not appear to update the original tf.Variable.
This operation creates a new tensor by applying sparse updates to the input tensor.
tf.Variable.assign apparently needs a new tensor value
with the same shape as X in order to update the tf.Variable X.
assign(
value, use_locking=False, name=None, read_value=True
)
value A Tensor. The new value for this variable.
Upvotes: 1
Views: 1654
Reputation: 17239
About tf.tensor_scatter_nd_update: you're right that it returns a new tf.Tensor (and not a tf.Variable). But about assign, which is a method of tf.Variable, I think you somewhat misread the documentation; when assign is called on a slice of the variable, value is just the new item that you want to assign at those particular indices of your old variable.
AFAIK, in TensorFlow all tensors are immutable, like Python numbers and strings; as the TensorFlow tensor guide puts it, you can never update the contents of a tensor, only create a new one. NumPy-style item assignment (x[i] = value) on a tf.Tensor or a tf.Variable is still not supported. See the following GitHub issues for the discussion: #33131, #14132.
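To make the "not supported" part concrete, here is a minimal sketch of what happens if you try the NumPy-style syntax anyway. The helper try_item_assignment is a name I made up for illustration; it is not part of TensorFlow. As far as I know, neither class defines __setitem__, so Python raises a TypeError in both cases.

```python
import tensorflow as tf

t = tf.constant([1, 2, 3])   # an immutable tf.Tensor
v = tf.Variable([1, 2, 3])   # a tf.Variable (mutable, but only through its methods)


def try_item_assignment(obj):
    """Hypothetical helper: attempt obj[1] = 0 and report what happened."""
    try:
        obj[1] = 0
        return "ok"
    except TypeError:
        # Raised because the object does not support item assignment.
        return "TypeError"


tensor_result = try_item_assignment(t)
variable_result = try_item_assignment(v)
print(tensor_result, variable_result)
```

So for a tf.Variable the update has to go through a method call rather than through the = operator.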
In NumPy, we can do the in-place item assignment that you showed in the comment box.
import numpy as np
a = np.array([1,2,3])
print(a) # [1 2 3]
a[1] = 0
print(a) # [1 0 3]
A similar result can be achieved on a tf.Variable by calling the assign
method on a slice of it.
import tensorflow as tf
b = tf.Variable([1,2,3])
b.numpy() # array([1, 2, 3], dtype=int32)
b[1].assign(0)
b.numpy() # array([1, 0, 3], dtype=int32)
Later, we can convert it to a tf.Tensor as follows.
b_ten = tf.convert_to_tensor(b)
b_ten.numpy() # array([1, 0, 3], dtype=int32)
We can do such item assignment on a tf.Tensor too, but we need to convert it to a tf.Variable first (I know, not very intuitive).
tensor = [[1, 1], [1, 1], [1, 1]]  # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]         # num_updates == 2, index_depth == 2
updates = [5, 10]                  # num_updates == 2
x = tf.tensor_scatter_nd_update(tensor, indices, updates)
x
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 1,  5],
       [ 1,  1],
       [10,  1]], dtype=int32)>
x = tf.Variable(x)
x
<tf.Variable 'Variable:0' shape=(3, 2) dtype=int32, numpy=
array([[ 1,  5],
       [ 1,  1],
       [10,  1]], dtype=int32)>
x[0].assign([5, 1])
x
<tf.Variable 'Variable:0' shape=(3, 2) dtype=int32, numpy=
array([[ 5,  1],
       [ 1,  1],
       [10,  1]], dtype=int32)>
x = tf.convert_to_tensor(x)
x
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 5,  1],
       [ 1,  1],
       [10,  1]], dtype=int32)>
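If the data already lives in a tf.Variable, the round trip through tf.tensor_scatter_nd_update is probably not needed. A minimal sketch, assuming a TF 2.x eager setup, using the scatter_nd_update method of tf.Variable with the same indices and updates as above; the variable names are just for illustration.

```python
import tensorflow as tf

x = tf.Variable([[1, 1], [1, 1], [1, 1]])
indices = tf.constant([[0, 1], [2, 0]])  # one (row, column) pair per update
updates = tf.constant([5, 10])           # one new value per index pair

# Writes the new values into the variable itself; x stays the same
# tf.Variable object and no full-shape replacement value is passed in.
x.scatter_nd_update(indices, updates)

result = x.numpy().tolist()
print(result)  # [[1, 5], [1, 1], [10, 1]]
```

This is the method form of the sparse update, so only the indices and the new values have to be supplied, which seems to be what the question was asking for.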
Upvotes: 2