Reputation: 247
The weights of a dense layer in a neural network form an (n, d) matrix, and I want to force some of these weights to always be zero. I have another (n, d) matrix which is the mask of which entries can be non-zero. The idea is that the layer should not be truly dense, but have some connections missing (i.e. fixed at 0).
How can I achieve this while training with PyTorch (or TensorFlow)? I don't want these weights to become non-zero during training.
One method, if the framework doesn't support this directly, would be to zero out the desired entries after each training iteration, as sketched below.
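For reference, that fallback might look roughly like this in PyTorch (the names layer and mask are illustrative; mask is a 0/1 tensor with the same shape as layer.weight):

import torch

layer = torch.nn.Linear(5, 3)               # hypothetical layer: 5 inputs, 3 outputs
mask = (torch.rand(3, 5) > 0.5).float()     # 1 = connection allowed, 0 = forced to zero

# ... inside the training loop, right after optimizer.step():
with torch.no_grad():
    layer.weight.mul_(mask)                 # re-zero the forbidden entries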
Upvotes: 5
Views: 1655
Reputation: 469
I'll assume that you want to use dense tensors to implement this kind of sparsely connected layer. If so, you can define the mask matrix (tensor) to be 0.0 for the elements you want to mask (no connection) and 1.0 otherwise. In your forward pass, you then multiply your weight tensor element-wise with the mask tensor (the element-wise product is what the * operator computes in PyTorch) before doing the matrix multiplication with the input of your sparse layer.
To make this work properly, you have to make sure that the mask tensor does not receive a gradient, otherwise it would be updated and become invalid as you train your model. To do this, simply set requires_grad=False when you create the mask tensor.
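A minimal sketch of this approach (the class and variable names are mine, not from any library), assuming the mask has the same (out_features, in_features) shape as the weight of nn.Linear; registering the mask as a buffer keeps it out of the gradient computation, which has the same effect as creating it with requires_grad=False:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, mask):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # buffer: saved and moved with the module, but never trained
        self.register_buffer("mask", mask.float())

    def forward(self, x):
        # zero the forbidden connections before the matrix multiplication
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)

# e.g. keep roughly half of the connections of a 5 -> 3 layer
mask = (torch.rand(3, 5) > 0.5)
layer = MaskedLinear(5, 3, mask)
y = layer(torch.randn(8, 5))  # -> shape (8, 3)

The underlying weights at the masked positions keep their initial values but never contribute, since every forward pass uses weight * mask.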
Upvotes: 1
Reputation: 114796
You can take advantage of PyTorch's sparse data type:
import torch
import torch.nn as nn

class SparseLinear(nn.Module):
    def __init__(self, in_features, out_features, sparse_indices):
        super(SparseLinear, self).__init__()
        # sparse COO weight of shape (out_features, in_features): only the listed index pairs exist and receive gradients
        self.weight = nn.Parameter(torch.sparse_coo_tensor(sparse_indices, torch.randn(sparse_indices.shape[1]), [out_features, in_features]))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        # x: (batch, in_features) -> (batch, out_features)
        return torch.sparse.mm(self.weight, x.t()).t() + self.bias
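A quick usage sketch under the above assumptions (the first row of sparse_indices holds output-unit indices, the second row input-unit indices); note that the sparse weight produces a sparse gradient, so the optimizer has to support that (plain SGD does, for example):

indices = torch.tensor([[0, 1, 2],   # output unit of each allowed connection
                        [3, 0, 4]])  # input unit of each allowed connection
layer = SparseLinear(in_features=5, out_features=3, sparse_indices=indices)
out = layer(torch.randn(8, 5))       # -> shape (8, 3)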
Upvotes: 3