msm

Reputation: 245

Creating and training only specified weights in TensorFlow or PyTorch

I am wondering if there is a way in TensorFlow, PyTorch or some other library to selectively connect neurons. I want to make a network with a very large number of neurons in each layer, but that has very few connections between layers.

Note that I do not think this is a duplicate of this answer: Selectively zero weights in TensorFlow?. I implemented a custom Keras layer using essentially the same method that appears in that question - creating a dense layer where all but the specified weights are ignored in training and evaluation. This fulfills part of what I want by not training the unspecified weights and not using them for prediction. But the problem is that I still waste memory storing the untrained weights, and I waste time computing gradients for the zeroed weights. What I would like is for the gradient computation to involve only sparse matrices, so that I waste neither time nor memory.

Is there a way to selectively create and train weights without wasting memory? If my question is unclear or there is more information that it would be helpful for me to provide, please let me know. I would like to be helpful as a question-asker.

Upvotes: 1

Views: 1266

Answers (2)

Florian Drawitsch

Reputation: 715

Both TensorFlow and PyTorch support sparse tensors (torch.sparse, tf.sparse).

My intuitive understanding would be that if you were willing to write your network using the respective low level APIs (e.g. actually implementing the forward-pass yourself), you could cast your weight matrices as sparse tensors. That would in turn result in sparse connectivity, since the weight matrix of layer [L] defines the connectivity between neurons of the previous layer [L-1] with neurons of layer [L].
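A minimal PyTorch sketch of this idea, assuming you are willing to write the forward pass yourself. The connectivity pattern and sizes below are made up for illustration; only the weights that actually exist are stored and receive gradients:

```python
import torch

# Hypothetical connectivity: 4 input neurons -> 3 output neurons,
# with only 3 connections in total.
indices = torch.tensor([[0, 1, 2],   # row = output neuron index
                        [0, 2, 3]])  # col = input neuron index
values = torch.randn(3, requires_grad=True)  # the only stored weights
weight = torch.sparse_coo_tensor(indices, values, size=(3, 4))

x = torch.randn(4, 2)           # batch of 2 column vectors
y = torch.sparse.mm(weight, x)  # forward pass: sparse (3,4) @ dense (4,2)

loss = y.sum()
loss.backward()
# Only the 3 existing weights get a gradient entry.
print(values.grad.shape)
```

Whether this actually saves time depends on the sparsity level and backend; for mildly sparse matrices, dense kernels are often still faster.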

Upvotes: 1

cheersmate

Reputation: 2656

The usual, simple solution is to initialize your weight matrices with zeros where there should be no connection. You store a mask of the locations of these zeros, and set the weights at these positions back to zero after each weight update. You need to do this because the gradient at a zero weight may be nonzero, and an update would then introduce nonzero weights (i.e. connections) where you don't want any.

Pseudocode:

# setup network
weights = sparse_init()  # only nonzero for existing connections
zero_mask = where(weights == 0)

# train
for e in range(num_epochs):
    train_operation()  # may lead to introduction of new connections
    weights[zero_mask] = 0  # so we set them to zero again
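A runnable PyTorch sketch of this masking scheme (the layer sizes, sparsity level, and loss are placeholders chosen for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sparse layer: 8 -> 4, keeping roughly 25% of connections.
layer = nn.Linear(8, 4, bias=False)
mask = (torch.rand_like(layer.weight) < 0.25).float()
with torch.no_grad():
    layer.weight *= mask  # zero out the non-connections at init

optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
for _ in range(5):  # dummy training loop
    x = torch.randn(16, 8)
    loss = layer(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()       # may introduce nonzeros at masked positions...
    with torch.no_grad():
        layer.weight *= mask  # ...so re-apply the mask after each update

# Masked positions remain exactly zero throughout training.
assert (layer.weight[mask == 0] == 0).all()
```

Note this still stores and differentiates the full dense matrix; it only guarantees the masked connections stay at zero.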

Upvotes: 2
