Reputation: 1
Firstly, the network function is defined:
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition

def net(X):
    return softmax(torch.matmul(X.reshape(-1, W.shape[0]), W) + b)
Then the parameters are updated by training:
train(net,train_iter,test_iter,cross_entropy,num_epoches,updater)
Finally, the function is saved and loaded for prediction:
PATH='./net.pth'
torch.save(net,PATH)
saved_net=torch.load(PATH)
predict(saved_net,test_iter,6)
The prediction results show that the updated parameters W and b were not saved and loaded. What is the correct way to save custom functions and the updated parameters?
Upvotes: 0
Views: 1002
Reputation: 40768
The correct way is to implement your own nn.Module and then use the provided utilities to save and load the model's state (its weights) on demand. You must define two functions:

- __init__: the class initializer, where you define your model's parameters.
- forward: the function that implements the model's forward pass.
A minimal example would be of the form:
class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Wrap the tensors in nn.Parameter so they are registered with the
        # module and included in its state_dict.
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    def softmax(self, X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdim=True)
        return X_exp / partition

    def forward(self, X):
        return self.softmax(torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b)
You can initialize a new model by doing:
>>> model = LinearSoftmax(10, 3)
You can then save and load the weights W and b of a given instance:

- Save the dictionary returned by nn.Module.state_dict with torch.save:
>>> torch.save(model.state_dict(), PATH)
- Load the weights back into memory with torch.load and mount them on the model with nn.Module.load_state_dict:
>>> model.load_state_dict(torch.load(PATH))
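Putting the two steps together, a minimal round-trip sketch looks like the following (the file name linear_softmax.pth is arbitrary; the class is repeated here so the snippet is self-contained):

```python
import torch
import torch.nn as nn

class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # nn.Parameter registers the tensors in the module's state_dict
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    def forward(self, X):
        logits = torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b
        X_exp = torch.exp(logits)
        return X_exp / X_exp.sum(1, keepdim=True)

PATH = "linear_softmax.pth"  # arbitrary file name

# Train (or just initialize) a model, then save only its parameters.
model = LinearSoftmax(10, 3)
torch.save(model.state_dict(), PATH)

# Later: rebuild the architecture and mount the saved weights on it.
restored = LinearSoftmax(10, 3)
restored.load_state_dict(torch.load(PATH))
```

Because the class holds W and b as registered parameters, the restored instance reproduces the original model's predictions exactly, which is precisely what the free-function version of net cannot guarantee.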
Upvotes: 1