Reputation: 563
I am training a neural network using PyTorch, and I want to save the weights at every iteration. In other words, I want to create a list that contains all the weights the neural network has had during training.
I did the following:
for i, (images, labels) in enumerate(train_loader):
    # (.....code that is used to train the model here.....)
    weight = model.fc2.weight.detach().numpy()
    weights_list.append(weight)
When I then print the entries of the list 'weights_list', I notice that they are all the same, which cannot be true, because I have printed the weights during the training and they do change (and the network actually does learn, so they have to). My guess is that every entry of the list is actually a pointer to the weights of the network at the moment the list is checked. So:
1) Is my guess correct? 2) How can I solve this problem?
Thank you!
Upvotes: 1
Views: 894
Reputation: 22214
The functionality to save and load weights is built in. To save to a file you can use
torch.save(model.state_dict(), 'checkpoint.pt')
and to load you can use
model.load_state_dict(torch.load('checkpoint.pt'))
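Since the question asks for the weights at every iteration, one way to apply the two calls above is to write a separately numbered checkpoint file per step. This is only a rough sketch: the nn.Linear model, the random data, the temporary directory, and the file naming scheme are all assumptions made for illustration, not part of the original code.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Assumed location for the checkpoints; any writable directory works.
checkpoint_dir = tempfile.mkdtemp()
paths = []

for step in range(3):
    # Dummy training step on random data.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # One file per iteration; the file holds the values at save time.
    path = os.path.join(checkpoint_dir, f"checkpoint_{step:04d}.pt")
    torch.save(model.state_dict(), path)
    paths.append(path)

# Restore the weights from the first iteration into a fresh model.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(paths[0]))
```

Writing a file every iteration can get slow and take a lot of disk space for larger models, so saving every N steps may be more practical.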
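That said, your guess is correct: calling .detach().numpy() does not create a copy. The resulting numpy array shares memory with the parameter tensor, so the optimizer's in-place updates show up in every entry of your list, and all entries end up showing the latest weights. To store an independent snapshot you have to copy the data explicitly. For example, if you have a numpy array y and want to create a copy, you could use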
x = numpy.copy(y)
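Applied to the training loop from the question, that might look roughly like the sketch below. TinyNet, the random data, and the list names are made up for illustration; only the fc2 attribute comes from the question. It assumes the model lives on the CPU (for a GPU model, .cpu() would be needed before .numpy()).

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)


# Hypothetical stand-in for the asker's model; only the name fc2 is from the question.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(x)


model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

aliased_list = []  # reproduces the problem: arrays that view the live weights
copied_list = []   # independent snapshots

for step in range(3):
    # Dummy training step on random data.
    x = torch.randn(8, 4)
    y = torch.randn(8, 2)
    loss = ((model(x) - y) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # No copy: this array shares memory with the parameter tensor.
    aliased_list.append(model.fc2.weight.detach().numpy())
    # Explicit copy: this array keeps the values as they are right now.
    copied_list.append(model.fc2.weight.detach().numpy().copy())

# Every aliased entry shows the same (latest) weights ...
all_aliased_equal = all(np.array_equal(aliased_list[0], w) for w in aliased_list)
# ... while the copied entries differ from step to step.
snapshots_differ = not np.array_equal(copied_list[0], copied_list[-1])
```

An alternative that stays in PyTorch is model.fc2.weight.detach().clone(), which gives an independent tensor instead of a numpy array.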
Upvotes: 1