Reputation: 22368
Please advise how to conditionally update the original tf.Variable. This is a different question from "Conditional assignment of tf.variable in Tensorflow 2".
tf.Variable is mutable, and its assign method updates the same memory area in place. However, assign does not seem to have an option to apply a condition when assigning values. Hence I suppose tf.where is the way to update a tf.Variable conditionally.
The tf.where documentation describes its return value as follows:
Returns:
If x and y are provided:
A Tensor with the same type as x and y, and shape that is broadcast from the condition, x, and y.
Otherwise:
A Tensor with shape (num_true, dim_size(condition)).
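To make the two return modes quoted above concrete, here is a small sketch, assuming TensorFlow 2.x in eager mode (the tensors cond, x and y are illustrative only). In both modes tf.where produces a new Tensor; it never modifies any of its inputs:

```python
import tensorflow as tf

cond = tf.constant([[True, False], [False, True]])
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.zeros_like(x)

# Mode 1, x and y given: element-wise select. The result is a new Tensor
# whose shape is the broadcast shape of cond, x and y.
selected = tf.where(cond, x, y)
print(selected.numpy())  # values: [[1, 0], [0, 4]]

# Mode 2, no x and y: the int64 indices of the True elements,
# with shape (num_true, rank(cond)).
indices = tf.where(cond)
print(indices.numpy())   # values: [[0, 0], [1, 1]]
```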
In NumPy, boolean indexing can be used to update an array directly in place, but there seems to be no such way in TensorFlow.
import numpy as np

# NumPy conditional update with boolean indexing
x = np.random.uniform(-1, 1, size=(3, 4))
x[x > 0] = 0  # modifies x in place
As tf.Variable is mutable, I expected tf.where to mutate the original Variable. However, as shown below, the original Variable x has not been updated.
import numpy as np
import tensorflow as tf

x = tf.Variable(np.random.uniform(-1, 1, size=(3, 4)), dtype=tf.float32)
print(f"x:{x}\n")
print(f"x > 0:\n{x > 0}\n")
print(f"tf.where(x>0, 1, x):\n{tf.where(x>0, 1, x)}")
x  # check whether x was updated (displays x in a notebook)
Result:
# --------------------------------------------------------------------------------
# Original tf.Variable x
# --------------------------------------------------------------------------------
x is <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.8015974 , 0.8223503 , -0.2704468 , 0.01874248],
[ 0.46989247, 0.4753061 , -0.06808566, -0.57646054],
[ 0.07082719, 0.2924774 , -0.12741995, 0.3168819 ]],
dtype=float32)>
x > 0 is [[ True True False True]
[ True True False False]
[ True True False True]]
# --------------------------------------------------------------------------------
# Update using tf.where
# --------------------------------------------------------------------------------
tf.where(x>0, 1, x)=
[[ 1. 1. -0.2704468 1. ]
[ 1. 1. -0.06808566 -0.57646054]
[ 1. 1. -0.12741995 1. ]]
# --------------------------------------------------------------------------------
# The x is the same as before.
# --------------------------------------------------------------------------------
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.8015974 , 0.8223503 , -0.2704468 , 0.01874248],
[ 0.46989247, 0.4753061 , -0.06808566, -0.57646054],
[ 0.07082719, 0.2924774 , -0.12741995, 0.3168819 ]],
dtype=float32)>
Please help me understand whether there is a way to directly update x.
First, the return value from the slice op is a Tensor, which is not the original variable:
var_slice = var[4:5]
var_slice.assign(math_ops.sub(var, const))
A Tensor does not have methods such as assign or assign_add.
I think the most feasible way is to create a TensorArray that holds your slices of the mean/variance values, and read from and write to it within the cell's call() body. Once the layer has gone over all the timesteps, you can stack the TensorArray and assign the result back to the variable itself. You don't have to write to the variable continuously while processing the timesteps, since timestep t shouldn't affect the result at t+1 (if I understand your problem correctly).
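A minimal sketch of this TensorArray approach, assuming TensorFlow 2.x in eager mode. NUM_STEPS, stats_var, process and the doubling step are hypothetical stand-ins for the real cell and its mean/variance computation; the point is only that per-timestep results are collected in a TensorArray and written to the variable once:

```python
import tensorflow as tf

NUM_STEPS = 4  # hypothetical number of timesteps

# Hypothetical non-trainable variable that should receive one value per timestep
stats_var = tf.Variable(tf.zeros([NUM_STEPS]), trainable=False)

def process(inputs):
    """Collect per-timestep results in a TensorArray, then assign once."""
    ta = tf.TensorArray(tf.float32, size=NUM_STEPS)
    for t in range(NUM_STEPS):
        # Placeholder per-timestep computation: double the input at step t
        ta = ta.write(t, inputs[t] * 2.0)
    # Stack the collected values into shape (NUM_STEPS,) and write back once
    stats_var.assign(ta.stack())

process(tf.constant([1.0, 2.0, 3.0, 4.0]))
print(stats_var.numpy())  # values: [2, 4, 6, 8]
```

Writing to the variable only after the loop keeps the per-step work free of variable side effects, which matches the assumption that step t does not depend on what was stored at step t+1.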
Upvotes: 1
Views: 592
Reputation: 17239
Here are a few more cases. The main takeaway is that to update the content of a tf.Variable, you use assign. Note that it assigns the value eagerly.
>>> import tensorflow as tf
>>> x = tf.Variable(range(5), shape=(5,), name="a")
>>> x.numpy()
array([0, 1, 2, 3, 4], dtype=int32)
>>> x.assign(tf.where(x > 2, x, -1)).numpy()
array([-1, -1, -1, 3, 4], dtype=int32)
>>> x.assign(range(5,10)).numpy()
array([5, 6, 7, 8, 9], dtype=int32)
>>> x.assign(tf.where(x < 8, tf.range(5), tf.where(x > 8 , x, -1))).numpy()
array([ 0, 1, 2, -1, 9], dtype=int32)
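The answer notes that assign applies the value eagerly. As a sketch, assuming TensorFlow 2.x, the same tf.where plus assign pattern can also be wrapped in a tf.function, where the assignment runs as part of the traced graph each time the function is called (zero_negatives is a hypothetical name):

```python
import tensorflow as tf

x = tf.Variable([-2.0, -1.0, 0.5, 3.0])

@tf.function
def zero_negatives(v):
    # tf.where builds the new values as a Tensor; assign then writes
    # them into the variable's own storage
    v.assign(tf.where(v < 0.0, tf.zeros_like(v), v))

zero_negatives(x)
print(x.numpy())  # values: [0, 0, 0.5, 3]
```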
Upvotes: 1
Reputation: 11651
To update a tf.Variable
, you need to call assign
. Read more in the Introduction to Variables guide.
>>> import tensorflow as tf
>>> tf.random.set_seed(0)
>>> x = tf.Variable(tf.random.uniform((3,4),-1,1), dtype=tf.float32)
>>> x
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[-0.41604972, -0.5868671 , 0.07078147, 0.12251496],
[-0.16665101, 0.6156559 , -0.0135498 , 0.9962585 ],
[ 0.3934703 , -0.7492528 , 0.4196334 , 0.32483125]],
dtype=float32)>
>>> x.assign(tf.where(x>0,x,0))
<tf.Variable 'UnreadVariable' shape=(3, 4) dtype=float32, numpy=
array([[0. , 0. , 0.07078147, 0.12251496],
[0. , 0.6156559 , 0. , 0.9962585 ],
[0.3934703 , 0. , 0.4196334 , 0.32483125]], dtype=float32)>
>>> x
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[0. , 0. , 0.07078147, 0.12251496],
[0. , 0.6156559 , 0. , 0.9962585 ],
[0.3934703 , 0. , 0.4196334 , 0.32483125]], dtype=float32)>
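Putting this together, the NumPy boolean-index update from the question (x[x > 0] = 0) can be mirrored with a small helper. The name assign_where is hypothetical, not a TensorFlow API; this is a sketch assuming TensorFlow 2.x in eager mode, checked against the NumPy version:

```python
import numpy as np
import tensorflow as tf

def assign_where(var, mask, value):
    """Hypothetical helper: TF analogue of NumPy's `arr[mask] = value`.

    Elements where `mask` is True get `value`; the rest keep their current
    contents. The variable itself is updated via `assign`.
    """
    value = tf.cast(value, var.dtype)
    return var.assign(tf.where(mask, value, var))

rng = np.random.default_rng(0)
arr = rng.uniform(-1, 1, size=(3, 4)).astype(np.float32)
x = tf.Variable(arr)

# NumPy reference: zero out the positive entries in place
expected = arr.copy()
expected[expected > 0] = 0

assign_where(x, x > 0, 0)
print(np.array_equal(x.numpy(), expected))  # True
```

Casting value to the variable's dtype first avoids a dtype mismatch when a Python int such as 0 is passed for a float32 variable.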
Upvotes: 2