Let's say I've got an MxN-shaped `tf.Variable` which stores some state of my custom layer:
```python
import tensorflow as tf

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
# array([[0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        [0., 0., 0., 0.]], dtype=float32)>
```
I am aware that I could update the values of this variable with `v.assign(...)`, but how can I update just a sub-section of it? For example, I would like to insert a given vector at a given column:
```python
x = tf.ones([m, 1])
c = tf.Variable(2)
# update v by inserting x at column c
```
...such that the following would be the new values of `v`:
```python
# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
# array([[0., 0., 1., 0.],
#        [0., 0., 1., 0.],
#        [0., 0., 1., 0.]], dtype=float32)>
```
With TF 2.2, you can assign directly to a slice of the variable:
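```python
import tensorflow as tf

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones(m)  # shape (m,), matching the shape of the column slice v[:, c]
c = 2
# Sliced assignment: writes x into column c of v in place
change_v = v[:, c].assign(x)
```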
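If you would rather keep using a whole-variable `v.assign(...)`, as mentioned in the question, one option is to rebuild the full MxN value with a column mask and assign that. This is a minimal sketch, not part of the original answer, assuming TF 2.x, where `tf.where` broadcasts its three arguments; `col_mask` is just an illustrative name:

```python
import tensorflow as tf

m, n = 3, 4
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones([m, 1])  # column vector, as in the question
c = tf.Variable(2)   # column index stored in a variable, as in the question

# Boolean mask of shape (1, n) that is True only at column c.
col_mask = tf.equal(tf.range(n), c)[tf.newaxis, :]

# tf.where broadcasts the mask (1, n), x (m, 1) and v (m, n) to (m, n):
# take x where the mask is True, keep the old value of v elsewhere.
v.assign(tf.where(col_mask, x, v))
```

This builds a full MxN tensor on every update, whereas the sliced assignment writes only the selected column.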
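When the column index is itself a tensor (like `c = tf.Variable(2)` in the question) and the update runs inside a `tf.function`, `tf.Variable.scatter_nd_update` with explicit `(row, column)` index pairs is another option. This is a sketch under those assumptions, not from the original answer; `insert_column` is a hypothetical helper name:

```python
import tensorflow as tf

m, n = 3, 4
v = tf.Variable(tf.zeros([m, n]), trainable=False)

@tf.function
def insert_column(var, x, c):
    """Hypothetical helper: write vector x of shape (m,) into column c of var."""
    rows = tf.range(tf.shape(var)[0])          # [0, 1, ..., m-1]
    cols = tf.fill(tf.shape(rows), c)          # [c, c, ..., c]
    indices = tf.stack([rows, cols], axis=1)   # shape (m, 2): (row, c) pairs
    var.scatter_nd_update(indices, x)          # in-place update of those cells

insert_column(v, tf.ones(m), tf.constant(2))
```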