miho

Reputation: 12085

How to update a single column of a 2d tf.Variable?

Let's say I've got an MxN-shaped tf.Variable that stores some state of my custom layer:

import tensorflow as tf
m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 0., 0.],
#            [0., 0., 0., 0.],
#            [0., 0., 0., 0.]], dtype=float32)>

I am aware that I could update the values of this variable with v.assign(...), but how can I update just a sub-section of this variable? For example, I would like to insert a given vector at a given column.

x = tf.ones([m,1])
c = tf.Variable(2)
# update v by inserting x at column c

...such that the following would be the new values of v:

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 1., 0.],
#            [0., 0., 1., 0.],
#            [0., 0., 1., 0.]], dtype=float32)>

Upvotes: 0

Views: 295

Answers (1)

Marco Cerliani

Reputation: 22031

With TF 2.2, you can assign directly to a slice of the variable:

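import tensorflow as tf

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones(m)  # shape (m,), matching the shape of the slice v[:, c]
c = 2  # column index as a plain Python int

# Sliced assignment updates v in place.
change_v = v[:, c].assign(x)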
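
If x is a column vector of shape [m, 1], as in the question, a one-element column slice keeps the shapes aligned. This is only a minimal sketch, assuming eager execution in TF 2.x and a plain Python integer for the column index (the question uses a tf.Variable for c, which is not covered here):

```python
import tensorflow as tf

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones([m, 1])  # column vector, shape (m, 1), as in the question
c = 2  # assumption: plain Python int rather than a tf.Variable

# v[:, c:c + 1] has shape (m, 1), so x can be assigned without reshaping.
v[:, c:c + 1].assign(x)

result = v.numpy().tolist()
print(result)
```

Alternatively, tf.squeeze(x, axis=1) should flatten x to shape (m,) so that it matches the slice v[:, c].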

Upvotes: 1
