Reputation: 63
I want to implement a function that takes a variable as input, mutates some of its rows or columns, and writes them back into the original variable. I can do this for row slices using tf.gather and tf.scatter_update, but not for column slices, because tf.scatter_update only updates row slices and has no axis argument. I am not a TensorFlow expert, so I may be missing something. Can someone help?
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.scatter_update on a variable)

def matrix_reg(t, percent_t, beta):
    ''' Takes a variable tensor t as input and regularizes some of its rows.
    The number of rows to regularize is specified by percent_t. Returns the original tensor with the rows indexed by row_ind updated.
    Arguments:
    t -- input tensor (a tf.Variable)
    percent_t -- fraction (0 to 1) of the total rows to regularize
    beta -- the regularization factor
    Output:
    the regularized tensor
    '''
    # pick int(percent_t * n_rows) distinct row indices at random
    row_ind = np.random.choice(int(t.shape[0]), int(percent_t * int(t.shape[0])), replace=False)
    t_ = tf.gather(t, row_ind)
    t_reg = (1 + beta) * t_ - beta * (tf.matmul(tf.matmul(t_, tf.transpose(t_)), t_))
    # write the regularized rows back into the variable
    return tf.scatter_update(t, row_ind, t_reg)
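For reference, here is a minimal NumPy sketch of what the row version above computes, so the gather/update steps can be checked without a TensorFlow session. It is only an illustration, not TensorFlow code: fancy indexing t[row_ind] plays the role of tf.gather, and fancy assignment plays the role of tf.scatter_update. The name matrix_reg_rows is hypothetical, and row_ind is passed in explicitly (an assumption for reproducibility) instead of being drawn inside the function.

```python
import numpy as np

def matrix_reg_rows(t, row_ind, beta):
    """NumPy mirror of the row update (hypothetical helper, for illustration).

    row_ind is passed in explicitly; the original draws it with np.random.choice.
    """
    out = np.array(t, dtype=float)   # copy, because NumPy assignment is in place
    t_ = out[row_ind]                # like tf.gather(t, row_ind)
    t_reg = (1 + beta) * t_ - beta * (t_ @ t_.T @ t_)
    out[row_ind] = t_reg             # like tf.scatter_update(t, row_ind, t_reg)
    return out

rng = np.random.default_rng(0)
t = rng.random((5, 3))
# 40% of the 5 rows -> 2 distinct row indices, chosen as in the question
row_ind = rng.choice(t.shape[0], int(0.4 * t.shape[0]), replace=False)
t_new = matrix_reg_rows(t, row_ind, beta=0.1)
```

Rows not listed in row_ind come back unchanged; only the selected rows are replaced by the regularized block.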
Upvotes: 1
Views: 2227
Reputation: 327
Refer to the TensorFlow 2 documentation for tf.Variable.__getitem__(var, slice_spec):
Creates a slice helper object given a variable.
This allows creating a sub-tensor from part of the current contents of a variable. See tf.Tensor.__getitem__ for detailed examples of slicing.
This function in addition also allows assignment to a sliced range. This is similar to __setitem__ functionality in Python. However, the syntax is different so that the user can capture the assignment operation for grouping or passing to sess.run(). For example,...
Here is a minimal working example:
import tensorflow as tf
import numpy as np
var = tf.Variable(np.random.rand(3,3,3))
print(var)
# update the last column of each of the three 3x3 matrices with random integer values
# note that the update values need to have exactly the slice's shape (3, 3),
# since broadcasting is not supported here as of TF2
var[:,:,2].assign(np.random.randint(10,size=(3,3)))
print(var)
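Slice assignment works here because var[:,:,2] is a basic (strided) slice. As far as I know, tf.Variable slicing accepts integers, start:stop:step slices, Ellipsis and tf.newaxis, but not arbitrary index lists, so randomly chosen columns like those in the question generally cannot be assigned this way. The helper below is hypothetical plain Python (not a TensorFlow API); it checks whether a set of column indices happens to fit one start:stop:step slice:

```python
def as_strided_slice(indices):
    """Return a slice covering exactly `indices`, or None if no single slice fits.

    Hypothetical helper: only non-negative, strictly increasing, evenly spaced
    indices map to one start:stop:step slice.
    """
    idx = list(indices)
    if not idx:
        return None
    if len(idx) == 1:
        return slice(idx[0], idx[0] + 1, 1)
    step = idx[1] - idx[0]
    if step <= 0:
        return None  # duplicates or decreasing order: not a simple slice
    for a, b in zip(idx, idx[1:]):
        if b - a != step:
            return None  # uneven spacing
    return slice(idx[0], idx[-1] + 1, step)

print(as_strided_slice([0, 2]))     # usable as var[:, 0:3:2]
print(as_strided_slice([0, 1, 3]))  # None: needs a scatter-based update
```

When the helper returns None, a scatter-based update (for example tf.scatter_nd_update) is needed instead of slice assignment.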
Upvotes: 0
Reputation: 59731
Here is a small demonstration of how to update rows or columns. The idea is that you specify the row and column index in the variable where you want each element of the update to end up. That is easy to do with tf.meshgrid.
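To see which index tensor the meshgrid trick builds, here is a NumPy sketch. NumPy is used only for illustration; np.meshgrid with indexing='ij' follows the same convention as tf.meshgrid. Each element updates[i, j] is paired with the coordinate (row i, column indices[j]), and NumPy fancy assignment stands in for tf.scatter_nd_update:

```python
import numpy as np

n_rows = 4
col_indices = np.array([0, 2])
# Pair every row index with every selected column index; 'ij' indexing keeps
# rows on axis 0, matching the (n_rows, n_selected_cols) layout of `updates`.
grid = np.stack(np.meshgrid(np.arange(n_rows), col_indices, indexing="ij"), axis=-1)
print(grid.shape)  # (4, 2, 2): one (row, col) pair per update element

var = np.zeros((n_rows, 3))
updates = np.array([[1, 5], [2, 6], [3, 7], [4, 8]], dtype=float)
# Emulate scatter_nd_update: updates[i, j] lands at var[grid[i, j, 0], grid[i, j, 1]]
var[grid[..., 0], grid[..., 1]] = updates
print(var)
```

Starting from zeros, column 0 receives 1 to 4 and column 2 receives 5 to 8, while column 1 is untouched.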
import tensorflow as tf
var = tf.get_variable('var', [4, 3], tf.float32, initializer=tf.zeros_initializer())
updates = tf.placeholder(tf.float32, [None, None])
indices = tf.placeholder(tf.int32, [None])
# Update rows
var_update_rows = tf.scatter_update(var, indices, updates)
# Update columns
col_indices_nd = tf.stack(tf.meshgrid(tf.range(tf.shape(var)[0]), indices, indexing='ij'), axis=-1)
var_update_cols = tf.scatter_nd_update(var, col_indices_nd, updates)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('Rows updated:')
    print(sess.run(var_update_rows, feed_dict={updates: [[1, 2, 3], [4, 5, 6]], indices: [3, 1]}))
    print('Columns updated:')
    print(sess.run(var_update_cols, feed_dict={updates: [[1, 5], [2, 6], [3, 7], [4, 8]], indices: [0, 2]}))
Output:
Rows updated:
[[0. 0. 0.]
[4. 5. 6.]
[0. 0. 0.]
[1. 2. 3.]]
Columns updated:
[[1. 0. 5.]
[2. 5. 6.]
[3. 0. 7.]
[4. 2. 8.]]
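A different way to reason about the question's column case, shown in NumPy only: the columns of t are the rows of t.T, so the row formula can be applied to the transpose and the result transposed back. For a column block c of shape (m, k) this gives (1 + beta) * c - beta * c @ c.T @ c; that this is the column regularization you want is an assumption. In TensorFlow, tf.transpose returns a new tensor rather than a view of the variable, so scattering into it would not update the variable; this sketch only checks the math. The names reg_rows and matrix_reg_cols are hypothetical.

```python
import numpy as np

def reg_rows(block, beta):
    # The question's formula, applied to a (k, n) block of rows.
    return (1 + beta) * block - beta * (block @ block.T @ block)

def matrix_reg_cols(t, col_ind, beta):
    """Regularize columns col_ind of t by treating them as rows of t.T."""
    t_T = np.array(t, dtype=float).T.copy()  # rows of t_T are columns of t
    t_T[col_ind] = reg_rows(t_T[col_ind], beta)
    return t_T.T

t = np.arange(12, dtype=float).reshape(3, 4) / 10
out = matrix_reg_cols(t, [1, 3], beta=0.1)
```

Columns 0 and 2 of out equal those of t; columns 1 and 3 hold the regularized block.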
Upvotes: 2