Reputation: 137
I'm trying to use SparseTensor to represent weight variables in a fully-connected layer.
However, TensorFlow 0.8 does not seem to allow a SparseTensor to be used as a tf.Variable.
Is there any way to work around this?
I've tried:
import tensorflow as tf
a = tf.constant(1)
b = tf.SparseTensor([[0,0]],[1],[1,1])
print a.__class__ # shows <class 'tensorflow.python.framework.ops.Tensor'>
print b.__class__ # shows <class 'tensorflow.python.framework.ops.SparseTensor'>
tf.Variable(a) # Variable is declared correctly
tf.Variable(b) # Fail
By the way, my ultimate goal in using SparseTensor is to permanently mask some of the connections of a dense layer, so that these pruned connections are ignored when gradients are computed and applied.
In my current MLP implementation, SparseTensor and its sparse matmul ops produce correct inference outputs. However, the weights declared with SparseTensor are not updated as training proceeds.
Upvotes: 6
Views: 5277
Reputation: 21
The above code works with a minor correction, like this:
def optimize(loss, mask_tensor):
    optimizer = tf.train.AdamOptimizer(0.001)
    grads_and_vars = optimizer.compute_gradients(loss)
    # mask_tensor is indexed by variable here, so it is assumed to map
    # each variable to the mask for that variable's gradient
    modified_grads_and_vars = [
        (tf.multiply(gv[0], mask_tensor[gv[1]]), gv[1]) for gv in grads_and_vars
    ]
    return optimizer.apply_gradients(modified_grads_and_vars)
Upvotes: 0
Reputation: 11
TensorFlow doesn't support training on sparse tensors yet. You can initialize a sparse tensor as you wish, then convert it to a dense tensor and create a variable from it, like this:
# You need to correctly initialize the sparse tensor with indices, values and a shape
b = tf.SparseTensor(indices, values, shape)
b_dense = tf.sparse_tensor_to_dense(b)
b_variable = tf.Variable(b_dense)
You now have a dense variable initialized from a sparse tensor. Next you need to take care of the gradient update. Used naively, backpropagation computes a non-vanishing gradient for the entries that should be absent, so you have to make sure those entries in the variable stay 0.
To do this, TensorFlow optimizers have a method tf.train.Optimizer.compute_gradients(loss, [list_of_variables]). It calculates the gradients needed to minimize the loss function, but doesn't apply them yet, and returns a list of tuples of the form (gradient, variable). You can modify these gradients freely, but in your case it makes sense to mask the unneeded gradients to 0 (for example by creating another sparse tensor with default value 0.0 and value 1.0 where the weights in your network are present). After modifying them, call tf.train.Optimizer.apply_gradients(grads_and_vars) to actually apply the gradients. Example code would look like this:
# Create optimizer instance
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
# Get the gradients for your weights
grads_and_vars = optimizer.compute_gradients(loss, [b_variable])
# Modify the gradients at will
# In your case it would look similar to this
modified_grads_and_vars = [(tf.multiply(gv[0], mask_tensor), gv[1]) for gv in grads_and_vars]
# Apply modified gradients to your model
optimizer.apply_gradients(modified_grads_and_vars)
This makes sure the masked entries in your weight matrix stay 0 and no unwanted connections are created. Note that compute_gradients was called only for b_variable here, so the gradients of all other variables still have to be computed and applied separately.
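To see why multiplying the gradient by a 0/1 mask keeps pruned weights at zero, here is a minimal sketch in plain Python rather than TensorFlow. It only illustrates the arithmetic of one masked gradient-descent step. The function name masked_sgd_step and all the values are hypothetical and not part of any library.

```python
# Plain-Python illustration of gradient masking (no TensorFlow involved).
# masked_sgd_step and the example values are hypothetical.

def masked_sgd_step(weights, grads, mask, learning_rate):
    """One gradient-descent step where each gradient is multiplied by its mask entry."""
    return [
        [w - learning_rate * g * m for w, g, m in zip(w_row, g_row, m_row)]
        for w_row, g_row, m_row in zip(weights, grads, mask)
    ]

# 2x2 weight matrix whose off-diagonal connections are pruned (mask == 0).
mask = [[1.0, 0.0],
        [0.0, 1.0]]
weights = [[1.0, 0.0],
           [0.0, -1.0]]

# A dense gradient, non-zero everywhere, as naive backpropagation would give.
grads = [[1.0, 2.0],
         [3.0, 4.0]]

new_weights = masked_sgd_step(weights, grads, mask, learning_rate=0.5)
print(new_weights)
```

The pruned positions receive a zero update regardless of the gradient value, while the unmasked positions are updated as usual.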
Upvotes: 1
Reputation: 41
As a workaround to your problem, you can provide a tf.Variable (until TensorFlow v0.8) for the values of a sparse tensor. The sparsity structure has to be predefined in that case; the weights, however, remain trainable.
weights = tf.Variable(<initial-value>)
sparse_var = tf.SparseTensor(<indices>, weights, <shape>) # v0.8
sparse_var = tf.SparseTensor(<indices>, tf.identity(weights), <shape>) # v0.9
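The idea behind this workaround (fixed indices, trainable values) can be sketched in plain Python, without TensorFlow. The helper names sparse_matvec and value_grads, the example data, and the upstream gradient dy are all assumptions made for illustration. Only the values list is ever updated, so the sparsity pattern cannot change.

```python
# Plain-Python sketch: fixed sparsity pattern, trainable values only.
# sparse_matvec, value_grads and the data below are hypothetical.

indices = [(0, 0), (1, 2)]  # fixed (row, col) positions of the non-zero weights
values = [2.0, -1.0]        # the only trainable parameters
shape = (2, 3)

def sparse_matvec(indices, values, shape, x):
    """Compute y = W x for a sparse W given as (indices, values, shape)."""
    y = [0.0] * shape[0]
    for (row, col), v in zip(indices, values):
        y[row] += v * x[col]
    return y

def value_grads(indices, x, dy):
    """Gradient of the loss w.r.t. each stored value: dy[row] * x[col]."""
    return [dy[row] * x[col] for row, col in indices]

x = [1.0, 2.0, 3.0]
y = sparse_matvec(indices, values, shape, x)

dy = [1.0, 1.0]  # assumed upstream gradient d(loss)/dy
grads = value_grads(indices, x, dy)

# Gradient-descent update touches only the stored values.
values = [v - 0.5 * g for v, g in zip(values, grads)]
print(y, grads, values)
```

Because gradients exist only for the stored values, no update can ever create a connection outside the predefined indices.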
Upvotes: 4
Reputation: 8487
TensorFlow doesn't currently support sparse tensor variables. However, it does support sparse lookups (tf.embedding_lookup) and sparse gradient updates (tf.sparse_add) of dense variables. I suspect these two will suffice for your use case.
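As a rough illustration of what a sparse lookup and a sparse update of a dense variable mean, here is a plain-Python sketch. It is not the TensorFlow API. The helpers gather_rows and scatter_add are hypothetical names chosen for this example.

```python
# Plain-Python sketch of sparse lookup and sparse update on a dense table.
# gather_rows and scatter_add are hypothetical helpers, not TensorFlow calls.

def gather_rows(table, ids):
    """Return copies of the selected rows (a sparse lookup)."""
    return [list(table[i]) for i in ids]

def scatter_add(table, ids, updates):
    """Add each update row into table[ids[k]] in place; repeated ids accumulate."""
    for i, update in zip(ids, updates):
        for j, u in enumerate(update):
            table[i][j] += u

table = [[0.0, 0.0],
         [1.0, 1.0],
         [2.0, 2.0]]

rows = gather_rows(table, [2, 0])

# Only row 2 is modified; rows 0 and 1 are never touched.
scatter_add(table, [2, 2], [[0.5, 0.5], [0.5, 0.5]])
print(rows, table)
```

Only the rows named in ids are read or written, which is what makes this kind of update cheap when most of the dense variable is untouched in a given step.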
Upvotes: 2