Reputation: 1199
I have trained a model in PyTorch - an RCNN for text classification. The model has very high precision and recall, but I may eventually receive new documents with text unlike what I used to train, validate, or test the model.
I would like to add new text samples to the model without retraining the model from the beginning. This is desirable because I may lose access to some of the text used for initial training.
If it is not possible to add samples (documents), is it possible to train a new model on only the new samples and then somehow combine the original model and the new model? How?
Here is what my model looks like.
RCNN(
  (embeddings): Embedding(10661, 300)
  (lstm): LSTM(300, 64, bidirectional=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (W): Linear(in_features=428, out_features=64, bias=True)
  (tanh): Tanh()
  (fc): Linear(in_features=64, out_features=3, bias=True)
  (softmax): Softmax(dim=1)
  (loss_op): NLLLoss()
)
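The answer below depends on having the same nn.Module class that produced this printout. The class here is a hypothetical reconstruction. Only the layer names and sizes come from the printout. The forward pass (BiLSTM context concatenated with the word embedding, a tanh projection, then max-pooling over time) is an assumption based on the usual RCNN design, and so is the placement of dropout. The arithmetic does check out: 300 embedding dimensions plus 2 × 64 LSTM directions gives W's in_features of 428.

```python
import torch
import torch.nn as nn


class RCNN(nn.Module):
    """Hypothetical reconstruction: layer shapes match the printout, forward() is assumed."""

    def __init__(self, vocab_size=10661, embed_dim=300, hidden_size=64,
                 num_classes=3, dropout=0.0):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # word embedding (300) + left and right LSTM context (2 * 64) = 428
        self.W = nn.Linear(embed_dim + 2 * hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.fc = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x: (seq_len, batch) token ids; nn.LSTM is not batch_first by default
        embedded = self.embeddings(x)                      # (seq_len, batch, 300)
        context, _ = self.lstm(embedded)                   # (seq_len, batch, 128)
        combined = torch.cat([context, embedded], dim=2)   # (seq_len, batch, 428)
        hidden = self.tanh(self.W(combined))               # (seq_len, batch, 64)
        pooled = hidden.max(dim=0).values                  # max over time: (batch, 64)
        return self.softmax(self.fc(self.dropout(pooled)))  # (batch, 3) probabilities
```

For example, RCNN()(torch.randint(0, 10661, (20, 4))) returns a (4, 3) tensor whose rows sum to 1. loss_op is left out of the sketch. NLLLoss expects log-probabilities, so it is normally paired with LogSoftmax rather than Softmax. Alternatively, use CrossEntropyLoss on the raw fc output.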
I am aware of the techniques for saving a model and the corresponding loading techniques:
torch.save(model.state_dict(), PATH)
torch.save(model, PATH)
torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)
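The third (checkpoint dictionary) variant is the one that lets training resume with the optimizer state intact. Here is a minimal sketch of its load side. nn.Linear stands in for the RCNN, and the file name and the epoch/loss values are hypothetical.

```python
import torch
import torch.nn as nn

PATH = "checkpoint.pt"  # hypothetical path

model = nn.Linear(428, 64)  # hypothetical stand-in for the trained RCNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save everything needed to resume training (hypothetical epoch and loss values).
torch.save({
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.05,
}, PATH)

# Load: rebuild the same classes first, then restore their saved state.
restored = nn.Linear(428, 64)
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
checkpoint = torch.load(PATH)
restored.load_state_dict(checkpoint["model_state_dict"])
restored_opt.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
restored.train()  # back to training mode before continuing
```

Restoring the optimizer state matters for optimizers like Adam, whose running moment estimates would otherwise restart from zero.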
I can continue training based on the original samples, but I do not know how to add samples.
If this is something TensorFlow can do but PyTorch cannot, I might switch to TensorFlow.
Upvotes: 0
Views: 840
Reputation: 40708
Assuming you have your model's state saved in some file PATH, you can load it back into memory with torch.load. By default, each tensor is loaded onto the device (CPU or CUDA) it was on when torch.save was called.
state_dict = torch.load(PATH)
model.load_state_dict(state_dict)
This assumes model is an instance of the same nn.Module class that was used to save the state to PATH. model will then have an identical state (the same weights and biases) as when it was saved with torch.save. From there you can continue training model on the new data, i.e. fine-tune it, without starting from scratch.
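Here is a minimal sketch of that load-then-fine-tune workflow. TinyClassifier is a hypothetical stand-in (swap in the real RCNN class), model.pt is a hypothetical file, and random token ids replace the real new documents. The small learning rate is a common choice that keeps the weights close to the original solution. Training only on new data can still degrade accuracy on the original kind of text (catastrophic forgetting). If you can, keep a held-out set of old examples to check for it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the trained RCNN; use the real class here."""

    def __init__(self, vocab_size=10661, embed_dim=300, num_classes=3):
        super().__init__()
        self.embeddings = nn.EmbeddingBag(vocab_size, embed_dim)  # mean of token vectors
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, token_ids):
        return self.fc(self.embeddings(token_ids))  # raw logits


PATH = "model.pt"  # hypothetical path
torch.save(TinyClassifier().state_dict(), PATH)  # stands in for the original run

# 1. Rebuild the same architecture and restore the saved weights.
model = TinyClassifier()
model.load_state_dict(torch.load(PATH))

# 2. Hypothetical new samples: 32 documents of 20 token ids, 3 classes.
new_x = torch.randint(0, 10661, (32, 20))
new_y = torch.randint(0, 3, (32,))
loader = DataLoader(TensorDataset(new_x, new_y), batch_size=8, shuffle=True)

# 3. Fine-tune with a small learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()


def full_loss():
    model.eval()
    with torch.no_grad():
        return loss_fn(model(new_x), new_y).item()


before = full_loss()
for epoch in range(5):
    model.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
after = full_loss()
print(f"loss on new data: {before:.4f} -> {after:.4f}")
```

The same loop works if you restore the optimizer from a checkpoint dictionary instead of creating a fresh one.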
Note: You can load it directly onto the desired device by passing a torch.device to torch.load's map_location argument.
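Here is a minimal sketch of map_location. It assumes a hypothetical file model.pt and uses a small nn.Linear in place of the real model. It saves from a GPU if one is present, then forces every tensor onto the CPU at load time.

```python
import torch
import torch.nn as nn

PATH = "model.pt"  # hypothetical path

# Save from a GPU when one is available, so loading actually has to remap.
save_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.save(nn.Linear(64, 3).to(save_device).state_dict(), PATH)

# map_location remaps every stored tensor onto the CPU while loading.
state_dict = torch.load(PATH, map_location=torch.device("cpu"))
model = nn.Linear(64, 3)
model.load_state_dict(state_dict)

# To target a GPU instead, pass e.g. map_location=torch.device("cuda:0")
# and call model.to("cuda:0") so the module's own parameters live there too.
```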
Upvotes: 1