Reputation: 2441
I'm interested in training both a CNN model and a simple linear feed-forward model in PyTorch and, after training, adding more filters to the CNN layers and more neurons to the linear layers, as well as more outputs to both (e.g. going from binary to multiclass classification). By adding them I specifically mean keeping the trained weights unchanged and giving randomly initialized values to the new weights.
There's an example of a CNN model here, and an example of a simple linear feed-forward model here.
Upvotes: 0
Views: 399
Reputation: 24814
This one was a bit tricky and requires slice (see this answer for more info about slice, but it should be intuitive). Also see this answer for the slice trick. Please see the comments for an explanation:
import torch


def expand(original: torch.nn.Module, *args, **kwargs):
    # *args and **kwargs are the constructor arguments of the new, larger layer.
    # Add other arguments if needed, like a different stride.
    # They won't change the weight shape, but may change behaviour.
    new = type(original)(*args, **kwargs)

    # I assume bias and weight exist; if not, do some checks
    new_weight_shape = new.weight.shape
    new_bias_shape = new.bias.shape
    original_weight_shape = original.weight.shape
    original_bias_shape = original.bias.shape

    # Quick check that the new layer is at least as large as the original
    assert all(n >= o for n, o in zip(new_weight_shape, original_weight_shape))
    assert new_bias_shape[0] >= original_bias_shape[0]

    # Copy without tracking gradients (instead of going through .data)
    with torch.no_grad():
        # Trained values fill the leading entries; bias is assumed to be 1D
        new.bias[: original_bias_shape[0]] = original.bias
        # Create slices 0:n for each dimension
        slicer = tuple(slice(0, dim) for dim in original_weight_shape)
        # And copy the trained weights into that block
        new.weight[slicer] = original.weight
    return new


layer = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
new = expand(layer, in_channels=32, out_channels=64, kernel_size=3)
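Since the question also asks about linear layers and about going from binary to multiclass outputs, here is a minimal sketch of the same idea applied to torch.nn.Linear. The expand function is restated in compact form so the snippet runs on its own, and the layer sizes (10 to 20 input features, 2 to 5 outputs) are made-up values for illustration only.

```python
import torch


def expand(original: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
    """Compact restatement of expand from above (weight and 1D bias assumed)."""
    new = type(original)(*args, **kwargs)
    with torch.no_grad():
        new.bias[: original.bias.shape[0]] = original.bias
        slicer = tuple(slice(0, dim) for dim in original.weight.shape)
        new.weight[slicer] = original.weight
    return new


# Hypothetical sizes: a 2-class head over 10 features grows to 5 classes over 20
old_head = torch.nn.Linear(in_features=10, out_features=2)
new_head = expand(old_head, in_features=20, out_features=5)

# Linear stores its weight as (out_features, in_features), so the trained
# block ends up in the top-left corner of the new weight matrix
assert torch.equal(new_head.weight[:2, :10], old_head.weight)
assert torch.equal(new_head.bias[:2], old_head.bias)
print(new_head.weight.shape)  # torch.Size([5, 20])
```

The remaining entries keep whatever the layer's default initializer produced, which is the "randomly initialized" part asked for in the question.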
This should work for any layer that has weight and bias (adjust if needed). Using this approach you can recreate your neural network or use PyTorch's apply (docs here). Also remember that you have to explicitly pass the constructor *args and **kwargs for the "new layer", which will receive the trained connections.
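As a rough sketch of what recreating a network could look like, the snippet below walks over the children of a small torch.nn.Sequential and swaps selected layers for expanded ones. It uses named_children and setattr rather than apply, because each layer needs its own constructor arguments. The model, the new_sizes mapping and all layer sizes are hypothetical, and expand is restated compactly so the snippet is self-contained.

```python
import torch


def expand(original: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
    """Compact restatement of expand from above (weight and 1D bias assumed)."""
    new = type(original)(*args, **kwargs)
    with torch.no_grad():
        new.bias[: original.bias.shape[0]] = original.bias
        slicer = tuple(slice(0, dim) for dim in original.weight.shape)
        new.weight[slicer] = original.weight
    return new


# Hypothetical "trained" model: 8 features -> 16 hidden units -> 2 classes
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 2),
)
old_first_weight = model[0].weight.detach().clone()

# Hypothetical mapping: child name -> constructor kwargs of the enlarged layer.
# Neighbouring layers must stay consistent (32 hidden units on both sides).
new_sizes = {
    "0": dict(in_features=8, out_features=32),
    "2": dict(in_features=32, out_features=4),
}

# Replace each listed child by its expanded version
for name, child in list(model.named_children()):
    if name in new_sizes:
        setattr(model, name, expand(child, **new_sizes[name]))

out = model(torch.randn(3, 8))
print(out.shape)  # torch.Size([3, 4])
```

Note that the expanded network will generally not reproduce the old outputs exactly, since the new randomly initialized connections also feed into the old output units.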
Upvotes: 1