BuDiu

Reputation: 1

How do I save custom functions and parameters in PyTorch?

Firstly, the network function is defined:

def softmax(X):
    X_exp=torch.exp(X)
    partition=X_exp.sum(1,keepdim=True)
    return X_exp/partition

def net(X):
    return softmax(torch.matmul(X.reshape(-1,W.shape[0]),W)+b)

Then, the parameters are updated by training:

train(net,train_iter,test_iter,cross_entropy,num_epoches,updater)

Finally, the model is saved and loaded for prediction:

PATH='./net.pth'
torch.save(net,PATH)
saved_net=torch.load(PATH)
predict(saved_net,test_iter,6)

The prediction results show that the updated parameters W and b were not saved and loaded. What is the correct way to save custom functions and their updated parameters?

Upvotes: 0

Views: 1002

Answers (1)

Ivan

Reputation: 40768

The correct way is to implement your own nn.Module and then use the provided utilities to save and load the model's state (its weights) on demand.

You must define two methods:

  • __init__: the class initializer logic where you define your model's parameters.

  • forward: the function which implements the model's forward pass.

A minimal example would be of the form:

import torch
import torch.nn as nn

class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # register W and b as nn.Parameter so they appear in state_dict()
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    def softmax(self, X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdim=True)
        return X_exp / partition

    def forward(self, X):
        return self.softmax(torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b)

You can initialize a new model by doing:

>>> model = LinearSoftmax(10, 3)

You can then save and load weights W and b of a given instance:
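A minimal sketch of that step, assuming the LinearSoftmax module above and a hypothetical file name model.pth: save only the state_dict (the registered parameters), then load it into a freshly constructed instance.

```python
import torch
import torch.nn as nn

# Minimal module matching the answer above; W and b are registered
# as nn.Parameter so they are included in state_dict().
class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    def forward(self, X):
        Z = torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b
        X_exp = torch.exp(Z)
        return X_exp / X_exp.sum(1, keepdim=True)

model = LinearSoftmax(10, 3)

# save only the parameters, not the whole pickled object
torch.save(model.state_dict(), 'model.pth')

# load them back into a new instance of the same class
restored = LinearSoftmax(10, 3)
restored.load_state_dict(torch.load('model.pth'))
```

Saving the state_dict rather than the whole module (as in the question's torch.save(net, PATH)) is the recommended pattern: it keeps the checkpoint independent of the code that pickled it, and it only works if W and b are registered as nn.Parameter, which is exactly what the plain-function approach in the question was missing.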

Upvotes: 1
