Luca

Update a subset of weights in TensorFlow

Does anyone know how to update a subset (i.e. only some indices) of the weights that are used in the forward propagation?

My guess is that I might be able to do that after applying compute_gradients as follows:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
grads_vars = optimizer.compute_gradients(loss, var_list=[weights, bias_h, bias_v])

...and then do something with the list of tuples in grads_vars.
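One way to act on that idea is gradient masking: compute the (gradient, variable) pairs, zero the gradient entries for the weights that must stay fixed, and only then apply the update. The sketch below is written for TF 2.x eager mode with tf.GradientTape instead of the TF1 optimizer API. The toy shapes, the `mask` tensor, and the hand-written descent step are assumptions for illustration only.

```python
import tensorflow as tf

# Hypothetical toy parameters: a 3x2 weight matrix and a length-2 bias.
weights = tf.Variable(tf.ones([3, 2]))
bias = tf.Variable(tf.zeros([2]))
x = tf.constant([[1.0, 2.0, 3.0]])
learning_rate = 0.1

# Hypothetical mask: 1.0 marks weights that may change, 0.0 marks frozen ones.
mask = tf.constant([[1.0, 0.0],
                    [0.0, 0.0],
                    [1.0, 1.0]])

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.matmul(x, weights) + bias)

grads = tape.gradient(loss, [weights, bias])

# Same kind of (gradient, variable) pairs that compute_gradients returns,
# but with the weight gradient zeroed wherever mask == 0.
grads_vars = [(grads[0] * mask, weights), (grads[1], bias)]

# Plain gradient-descent step: var <- var - learning_rate * grad.
for grad, var in grads_vars:
    var.assign_sub(learning_rate * grad)

print(weights.numpy())
```

With plain gradient descent a zero gradient leaves the entry exactly unchanged. Optimizers that keep state (momentum, Adam) can still move an entry whose current gradient is zero, so with those the mask alone may not be enough.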

Answers (3)

Crazy_LittleBoy

import tensorflow as tf

# In TF 2.x you can solve this with tf.tensor_scatter_nd_update, for example:
tensor = tf.constant([0, 0, 0, 0, 0, 0, 0, 0])  # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]  # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12]       # num_updates == 4
# Returns a new tensor; the input is not modified in place.
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
# tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
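To change the stored weights themselves, the result has to end up in a tf.Variable. A minimal sketch of two ways to do that in TF 2.x: assigning the out-of-place result back, or calling the variable's own `scatter_nd_update` method. The variable names `a` and `b` are hypothetical.

```python
import tensorflow as tf

indices = [[1], [3], [4], [7]]
updates = [9, 10, 11, 12]

# Option 1: build a new tensor out of place, then assign it to the variable.
a = tf.Variable([0, 0, 0, 0, 0, 0, 0, 0])
a.assign(tf.tensor_scatter_nd_update(a, indices, updates))

# Option 2: write the selected entries directly into the variable.
b = tf.Variable([0, 0, 0, 0, 0, 0, 0, 0])
b.scatter_nd_update(indices, updates)

print(a.numpy())  # [ 0  9  0 10 11  0  0 12]
print(b.numpy())  # [ 0  9  0 10 11  0  0 12]
```

Option 2 avoids materializing a full-size copy just to change four entries.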

user38157

The easiest way is to pull the tf.Variable into Python as a NumPy array with npvar = sess.run(tfvar), modify it with ordinary indexing such as npvar[1, 2] = -10, and then upload the modified data back into TensorFlow with sess.run(tfvar.assign(npvar)).

This is slow, since the whole variable is copied to the host and back for every change, so it is not really useful inside a training loop, but it does work.
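The same round trip in TF 2.x eager mode needs no session: `.numpy()` returns a host copy and `.assign()` writes a full array back. A minimal sketch; the 3x4 zero-initialized `tfvar` is a hypothetical stand-in for a real weight matrix.

```python
import tensorflow as tf

tfvar = tf.Variable(tf.zeros([3, 4]))

npvar = tfvar.numpy()   # copy the variable's current value to host memory
npvar[1, 2] = -10       # edit the copy with ordinary NumPy indexing
tfvar.assign(npvar)     # upload the whole array back into the variable

print(tfvar.numpy()[1, 2])  # -10.0
```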

Yaroslav Bulatov

You could use a combination of tf.gather and scatter_update. Here's an example that doubles the values at indices 0 and 2:

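import tensorflow as tf

# TF1-style graph code; tf.compat.v1 keeps it working under TF 2.x as well.
tf.compat.v1.disable_eager_execution()

indices = tf.constant([0, 2])
data = tf.compat.v1.Variable([1, 2, 3])
data_subset = tf.gather(data, indices)   # read only the entries to change
updated_data_subset = 2 * data_subset    # compute their new values
# Write the new values back into the variable at the same indices.
sparse_update = tf.compat.v1.scatter_update(data, indices, updated_data_subset)
init_op = tf.compat.v1.global_variables_initializer()

sess = tf.compat.v1.Session()
sess.run(init_op)
print("Values before:", sess.run([data]))
sess.run(sparse_update)
print("Values after:", sess.run([data]))

You should see: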

Values before: [array([1, 2, 3], dtype=int32)]
Values after: [array([2, 2, 6], dtype=int32)]
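In TF 2.x eager mode the same gather-then-scatter pattern can use the variable's own `scatter_update` method, which takes the new values and their indices packed into a tf.IndexedSlices. A minimal sketch of that variant:

```python
import tensorflow as tf

indices = tf.constant([0, 2])
data = tf.Variable([1, 2, 3])

print("Values before:", data.numpy())  # Values before: [1 2 3]
data_subset = tf.gather(data, indices)  # read only the selected entries
updated = 2 * data_subset               # compute their new values
# Write the new values back in place at the same indices.
data.scatter_update(tf.IndexedSlices(updated, indices))
print("Values after:", data.numpy())   # Values after: [2 2 6]
```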
