learner

Reputation: 3472

Saving the model architecture with activation functions in PyTorch

I use PyTorch to train neural networks. When I save the model, only the weights of the network are saved; the activation functions are not captured. If I then reload the saved weights into a model whose activation functions have been changed, loading still does not throw an error, and the network (obviously) outputs incorrect values. Is there a way to save the structure of the neural network along with the weights? An MWE is presented below.

import torch
from torch import nn

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
    def forward(self, inputs):
        return self.tanh(self.fc2(self.relu(self.fc1(inputs))))

To save

test = Test().float()
torch.save(test.state_dict(), "test.pt")

To load

import torch
from torch import nn

class Test1(nn.Module):
    def __init__(self):
        super(Test1, self).__init__()  # must name Test1 here, not Test
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
    def forward(self, inputs):
        return self.relu(self.fc2(self.tanh(self.fc1(inputs))))

test1 = Test1().float()
# Loads without error. However, the activation functions tanh and relu are swapped,
# so the network outputs incorrect values.
test1.load_state_dict(torch.load("test.pt"))
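For reference, listing the keys of the saved state_dict shows why no error is raised: only the parameters of fc1 and fc2 are stored, because nn.ReLU and nn.Tanh have no parameters or buffers. This sketch reuses the Test class from above.

```python
import torch
from torch import nn


class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()  # no parameters, so nothing goes into state_dict
        self.tanh = nn.Tanh()  # same here

    def forward(self, inputs):
        return self.tanh(self.fc2(self.relu(self.fc1(inputs))))


test = Test().float()
keys = list(test.state_dict().keys())
print(keys)  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

Since Test1 has layers with the same names and shapes, load_state_dict finds a match for every key and has no way of noticing that forward() changed.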

Is there a way to also capture the activation functions, while saving? Thanks.

Upvotes: 2

Views: 359

Answers (1)

Miłosz Bertman

Reputation: 1

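Well, as far as I understand, you have to load the state_dict into the same class it was saved from, much like unpickling an object. Right now you have two different classes, Test and Test1. The state_dict only stores the parameter tensors (the weights and biases of fc1 and fc2); the activation functions have no parameters, so nothing about them is saved and nothing in the file reveals that their order has changed.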
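One way to make "the same class" rebuild the right activations is to save the constructor arguments next to the weights and pass them back in on load. This is a minimal sketch, not a PyTorch feature: the ACTIVATIONS registry, the ConfigurableNet class and the "config"/"state_dict" checkpoint keys are hypothetical names chosen for illustration.

```python
import io

import torch
from torch import nn

# Hypothetical registry mapping names to activation classes.
ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh}


class ConfigurableNet(nn.Module):
    def __init__(self, hidden_act="relu", output_act="tanh"):
        super().__init__()
        # Plain dict attribute: remembers which activations were chosen.
        self.config = {"hidden_act": hidden_act, "output_act": output_act}
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.hidden_act = ACTIVATIONS[hidden_act]()
        self.output_act = ACTIVATIONS[output_act]()

    def forward(self, inputs):
        return self.output_act(self.fc2(self.hidden_act(self.fc1(inputs))))


model = ConfigurableNet("relu", "tanh")

# Save the architecture description together with the weights
# (an in-memory buffer here; a file path such as "test.pt" works the same way).
buffer = io.BytesIO()
torch.save({"config": model.config, "state_dict": model.state_dict()}, buffer)
buffer.seek(0)

# Rebuild the model from the stored config, then load the weights into it.
checkpoint = torch.load(buffer)
restored = ConfigurableNet(**checkpoint["config"])
restored.load_state_dict(checkpoint["state_dict"])
```

The activations are still defined in code, but the checkpoint now records which ones were used, so the reloaded model cannot silently swap them.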

There might be some Pythonic way of also saving an object's architecture and properties into one file, but I imagine it would be quite a hassle, since it would have to be built into many additional functions, etc.
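PyTorch does ship one built-in way to save the computation itself: TorchScript. torch.jit.script compiles forward(), so the saved file contains the order of the layers and activations, and torch.jit.load restores it without needing the original class definition. This is a minimal sketch assuming a PyTorch version that still provides torch.jit (newer releases also offer torch.export); the file name test_scripted.pt is arbitrary.

```python
import os
import tempfile

import torch
from torch import nn


class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 25)
        self.fc2 = nn.Linear(25, 10)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, inputs):
        return self.tanh(self.fc2(self.relu(self.fc1(inputs))))


model = Test().float()
model.eval()

# Compile forward() into TorchScript: weights AND computation graph.
scripted = torch.jit.script(model)

path = os.path.join(tempfile.mkdtemp(), "test_scripted.pt")
scripted.save(path)

# Loading does not use the Test class, so editing it later cannot change the result.
restored = torch.jit.load(path)

x = torch.randn(4, 10)
with torch.no_grad():
    expected = model(x)
    actual = restored(x)

print(torch.allclose(expected, actual))
```

By contrast, torch.save(model) on the whole module pickles only a reference to the class, so at load time it still runs whatever forward() the current class definition contains.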

Upvotes: 0
