Reputation: 3472
I use PyTorch to train neural networks. When I save a model, only the network's weights are saved; the activation functions are not captured. If I then reload the saved weights into a model whose activation functions have been changed, loading still does not throw an error, and the network outputs incorrect values (obviously). Is there a way to save the structure of the neural network along with the weights? An MWE is presented below.
import torch
from torch import nn

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, inputs):
        return self.tanh(self.fc2(self.relu(self.fc1(inputs))))
To save
test = Test().float()
torch.save(test.state_dict(), "test.pt")
To load
import torch
from torch import nn

class Test1(nn.Module):
    def __init__(self):
        super(Test1, self).__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, inputs):
        # Same layers as Test, but tanh and relu are interchanged.
        return self.relu(self.fc2(self.tanh(self.fc1(inputs))))

test1 = Test1().float()
# Loads without error. However, the activation functions tanh and relu
# are interchanged, and the network outputs incorrect values.
test1.load_state_dict(torch.load("test.pt"))
Is there a way to also capture the activation functions, while saving? Thanks.
Upvotes: 2
Views: 359
Reputation: 1
As far as I understand, you have to load the state_dict into the same class it was saved from, much like loading pickled objects. Right now you have two different classes, Test and Test1. load_state_dict only checks that the parameter names and shapes match, so the swapped activations go unnoticed.
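To see why the load succeeds, it may help to look at what the state_dict actually holds. The sketch below rebuilds the Test class from the question and lists the saved keys; nn.ReLU and nn.Tanh have no parameters or buffers, so nothing about them ends up in the file.

```python
import torch
from torch import nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, inputs):
        return self.tanh(self.fc2(self.relu(self.fc1(inputs))))

# Only the two Linear layers contribute entries; the activations have no state.
keys = list(Test().state_dict().keys())
print(keys)  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

Any module exposing the same four keys with the same shapes will accept these weights, regardless of what its forward method does.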
There might be a Pythonic way to also save an object's architecture and its properties in one file, but I imagine it would be quite a hassle, since it would need many additional functions.
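One possible way to do this, as a minimal sketch rather than a definitive solution, is to store a small description of the architecture next to the weights and rebuild the model from it on load. The names ConfigurableTest, ACTIVATIONS, save_checkpoint and load_checkpoint are hypothetical helpers made up for this illustration; only torch.save, torch.load, state_dict and load_state_dict are PyTorch calls.

```python
import io
import torch
from torch import nn

# Hypothetical lookup table from a saved name to an activation class.
ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh}

class ConfigurableTest(nn.Module):
    def __init__(self, hidden_act, output_act):
        super().__init__()
        # Keep the choices as plain strings so they can be saved with the weights.
        self.config = {"hidden_act": hidden_act, "output_act": output_act}
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.hidden_act = ACTIVATIONS[hidden_act]()
        self.output_act = ACTIVATIONS[output_act]()

    def forward(self, inputs):
        return self.output_act(self.fc2(self.hidden_act(self.fc1(inputs))))

def save_checkpoint(model, f):
    # One file holds both the architecture description and the weights.
    torch.save({"config": model.config, "state_dict": model.state_dict()}, f)

def load_checkpoint(f):
    checkpoint = torch.load(f)
    # Rebuild the model with the saved activations before loading the weights.
    model = ConfigurableTest(**checkpoint["config"])
    model.load_state_dict(checkpoint["state_dict"])
    return model

# Round trip through an in-memory buffer; a file path such as "test.pt" works too.
buffer = io.BytesIO()
model = ConfigurableTest("relu", "tanh")
save_checkpoint(model, buffer)
buffer.seek(0)
restored = load_checkpoint(buffer)
```

With this layout the activation order travels with the weights, so the reloaded model cannot silently swap relu and tanh. It still requires the class definition to be available when loading.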
Upvotes: 0