mon

Reputation: 22368

Tensorflow 2 - How to conditionally update values directly in tf.Variable

Question

Please advise how to conditionally update the original tf.Variable. This is a different question from conditional assignment of tf.variable in Tensorflow 2.

Background

tf.Variable is mutable, and its assign method updates the same memory area. However, assign does not appear to have an option for applying a condition when assigning values, so I assumed tf.where is the way to update a tf.Variable conditionally. The tf.where documentation describes its return value as follows.

Returns:
    If x and y are provided: 
        A Tensor with the same type as x and y, and shape that is broadcast from the condition, x, and y.
    Otherwise: 
        A Tensor with shape (num_true, dim_size(condition)).

In NumPy, boolean indexing can be used to update an array in place, but there appears to be no equivalent in TensorFlow.

# NumPy conditional update with boolean indexing
import numpy as np

x = np.random.uniform(-1, 1, size=(3, 4))
x[x > 0] = 0

Problem

Because tf.Variable is mutable, I expected tf.where to mutate the original Variable. However, as shown below, the original Variable x is not updated.

import numpy as np
import tensorflow as tf

x = tf.Variable(np.random.uniform(-1, 1, size=(3,4)), dtype=tf.float32)
print(f"x:{x}\n")
print(f"x > 0:\n{x > 0}\n")
print(f"tf.where(x>0, 1, x):\n{tf.where(x>0, 1, x)}")

x  # check if updated

Result:

# --------------------------------------------------------------------------------
# Original tf.Variable x
# --------------------------------------------------------------------------------
x is <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.8015974 ,  0.8223503 , -0.2704468 ,  0.01874248],
       [ 0.46989247,  0.4753061 , -0.06808566, -0.57646054],
       [ 0.07082719,  0.2924774 , -0.12741995,  0.3168819 ]],
      dtype=float32)>

x > 0 is [[ True  True False  True]
 [ True  True False False]
 [ True  True False  True]]

# --------------------------------------------------------------------------------
# Update using tf.where
# --------------------------------------------------------------------------------
tf.where(x>0, 1, x)=
[[ 1.          1.         -0.2704468   1.        ]
 [ 1.          1.         -0.06808566 -0.57646054]
 [ 1.          1.         -0.12741995  1.        ]]

# --------------------------------------------------------------------------------
# The x is the same as before.
# --------------------------------------------------------------------------------
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.8015974 ,  0.8223503 , -0.2704468 ,  0.01874248],
       [ 0.46989247,  0.4753061 , -0.06808566, -0.57646054],
       [ 0.07082719,  0.2924774 , -0.12741995,  0.3168819 ]],
      dtype=float32)>

Please help me understand if there is a way to directly update x.


Notes

First, the return value from the slice op is a Tensor, which is not the original variable.

var_slice = var[4:5]
var_slice.assign(math_ops.sub(var, const))

A Tensor doesn't have methods like "assign" or "assign_add".

I think the most feasible way is to create a TensorArray that contains your slices of mean/variance values, and read/write to it within the call() body of the cell. Once the layer has gone over all the timesteps, you can stack the TensorArray and assign the value back to the variable itself. You don't have to continuously write to the variable while processing the timesteps, since timestep t shouldn't affect the result at t+1 (if I understand your problem correctly).
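
The TensorArray pattern described in the note above might look like the following minimal sketch. The variable name stats, its shape, and the per-timestep values are hypothetical stand-ins; only the pattern (write per step, stack once, assign once) comes from the note.

```python
import tensorflow as tf

# Hypothetical variable holding one row of statistics per timestep.
num_steps, width = 3, 2
stats = tf.Variable(tf.zeros([num_steps, width], dtype=tf.float32))

# Collect per-timestep results in a TensorArray instead of
# writing to the variable at every step.
ta = tf.TensorArray(dtype=tf.float32, size=num_steps)
for t in range(num_steps):
    step_value = tf.fill([width], float(t))  # placeholder for a real per-step computation
    ta = ta.write(t, step_value)

# Stack the collected rows and write them back with a single assign.
stats.assign(ta.stack())
```

Here the variable is touched only once, after the loop, which is what the note suggests when the timesteps do not depend on each other.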

Upvotes: 1

Views: 592

Answers (2)

Innat

Reputation: 17239

Here are a few more cases. The main takeaway is that to update the contents of a tf.Variable, we use assign. Note that it assigns the value eagerly.

x = tf.Variable(range(5), shape=(5,), name="a")
x.numpy()
array([0, 1, 2, 3, 4], dtype=int32)

x.assign(tf.where(x > 2, x, -1)).numpy()
array([-1, -1, -1,  3,  4], dtype=int32)

x.assign(range(5,10)).numpy()
array([5, 6, 7, 8, 9], dtype=int32)

x.assign(tf.where(x < 8, tf.range(5), tf.where(x > 8 , x, -1))).numpy()
array([ 0,  1,  2, -1,  9], dtype=int32)
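
As a rough sketch, the same assign-plus-tf.where pattern can be wrapped in a small helper. The helper name zero_out_positive and the sample values are assumptions made for illustration. To my understanding the helper behaves the same way when traced with tf.function, which the second call tries out.

```python
import tensorflow as tf

def zero_out_positive(var):
    """Hypothetical helper: set every positive element of `var` to 0, in place."""
    # tf.where builds a new tensor; assign writes it back into the variable.
    var.assign(tf.where(var > 0, tf.zeros_like(var), var))

x = tf.Variable([0.5, -0.25, 1.0, -0.75], dtype=tf.float32)

# Plain eager call.
zero_out_positive(x)
after_eager = x.numpy().tolist()

# Reset the variable, then run the same helper traced as a graph.
x.assign([2.0, -1.0, 3.0, -4.0])
tf.function(zero_out_positive)(x)
after_graph = x.numpy().tolist()
```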

Upvotes: 1

Lescurel

Reputation: 11651

To update a tf.Variable, you need to call assign. Read more in the Introduction to Variables guide.

>>> import tensorflow as tf
>>> tf.random.set_seed(0)
>>> x = tf.Variable(tf.random.uniform((3,4),-1,1), dtype=tf.float32)
>>> x
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[-0.41604972, -0.5868671 ,  0.07078147,  0.12251496],
       [-0.16665101,  0.6156559 , -0.0135498 ,  0.9962585 ],
       [ 0.3934703 , -0.7492528 ,  0.4196334 ,  0.32483125]],
      dtype=float32)>
>>> x.assign(tf.where(x>0,x,0))
<tf.Variable 'UnreadVariable' shape=(3, 4) dtype=float32, numpy=
array([[0.        , 0.        , 0.07078147, 0.12251496],
       [0.        , 0.6156559 , 0.        , 0.9962585 ],
       [0.3934703 , 0.        , 0.4196334 , 0.32483125]], dtype=float32)>
>>> x
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[0.        , 0.        , 0.07078147, 0.12251496],
       [0.        , 0.6156559 , 0.        , 0.9962585 ],
       [0.3934703 , 0.        , 0.4196334 , 0.32483125]], dtype=float32)>
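
If only the selected elements should be written, one possible alternative (not part of the answer above) is to take the indices from the condition-only form of tf.where and pass them to the variable's scatter_nd_update method. This is a minimal sketch with made-up sample values, not a recommendation over assign.

```python
import tensorflow as tf

x = tf.Variable([[0.5, -0.25], [-0.75, 1.0]], dtype=tf.float32)

# With only a condition, tf.where returns the indices of the True
# entries, with shape (num_true, rank).
idx = tf.where(x > 0)

# One replacement value per selected element (zeros here).
updates = tf.zeros_like(tf.gather_nd(x, idx))

# Write just those elements back into the variable.
x.scatter_nd_update(idx, updates)
```

The elements that fail the condition are never rewritten, which is closer in spirit to the NumPy x[x > 0] = 0 idiom from the question.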

Upvotes: 2
