Jacob Quisenberry

Reputation: 1199

Add Samples after Partial Training in PyTorch

I have trained a model in PyTorch - an RCNN for text classification. The model has very high precision and recall, but I may eventually receive new documents with text unlike what I used to train, validate, or test the model.

I would like to add new text samples to the model without retraining the model from the beginning. This is desirable because I may lose access to some of the text used for initial training.

If it is not possible to add samples (documents), is it possible to train a new model on only the new samples and then somehow combine the original model and the new model? How?

Here is what my model looks like.

RCNN(
  (embeddings): Embedding(10661, 300)
  (lstm): LSTM(300, 64, bidirectional=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (W): Linear(in_features=428, out_features=64, bias=True)
  (tanh): Tanh()
  (fc): Linear(in_features=64, out_features=3, bias=True)
  (softmax): Softmax(dim=1)
  (loss_op): NLLLoss()
)
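The printout lists the layers but not the forward() method. The following is a hypothetical reconstruction that is consistent with the printed shapes. It assumes the common RCNN text-classification design, in which each word's BiLSTM output (2 × 64 = 128 features) is concatenated with its 300-dimensional embedding. That gives the 428 input features of W. It also assumes max-pooling over the sequence before fc. The actual model may differ.

```python
import torch
import torch.nn as nn


class RCNN(nn.Module):
    """Hypothetical reconstruction matching the printed layer names and shapes."""

    def __init__(self, vocab_size=10661, embed_dim=300, hidden=64,
                 hidden_linear=64, num_classes=3, dropout=0.0):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        # batch_first defaults to False, so the LSTM expects (seq_len, batch, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # 2 * hidden (both LSTM directions) + embed_dim = 128 + 300 = 428
        self.W = nn.Linear(2 * hidden + embed_dim, hidden_linear)
        self.tanh = nn.Tanh()
        self.fc = nn.Linear(hidden_linear, num_classes)
        self.softmax = nn.Softmax(dim=1)
        self.loss_op = nn.NLLLoss()

    def forward(self, x):
        # x: (seq_len, batch) tensor of token indices
        embedded = self.embeddings(x)                   # (seq_len, batch, 300)
        lstm_out, _ = self.lstm(embedded)               # (seq_len, batch, 128)
        combined = torch.cat([lstm_out, embedded], 2)   # (seq_len, batch, 428)
        latent = self.tanh(self.W(combined))            # (seq_len, batch, 64)
        pooled = latent.max(dim=0).values               # max over time -> (batch, 64)
        return self.softmax(self.fc(self.dropout(pooled)))  # (batch, 3)
```

Printing an instance of this class gives the same layer list as above. The max-pooling step and the (seq_len, batch) input layout are assumptions.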

I know how to save the model and how to load it back.

I can continue training based on the original samples, but I do not know how to add samples.

If this is something TensorFlow can do but PyTorch cannot, I might switch to TensorFlow.

Upvotes: 0

Views: 840

Answers (1)

Ivan

Reputation: 40708

Assuming you have saved your model's state to some file PATH, you can load it back into memory with torch.load. By default, the tensors are loaded onto the device (CPU or CUDA) they were on when torch.save was called.

state_dict = torch.load(PATH)
model.load_state_dict(state_dict)

This assumes model is an instance of the same nn.Module class that was used to save the state to PATH. model will then have an identical state (the same weights and biases) as when it was saved with torch.save. From there, you can fine-tune model on the new data.
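Here is a minimal sketch of that load-then-fine-tune workflow. To keep it self-contained, it uses a hypothetical stand-in classifier (make_model) and random placeholder "new documents" instead of the real RCNN and data. The file names and the small learning rate are illustrative choices, not requirements.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def make_model():
    # Hypothetical stand-in classifier; substitute your own RCNN class here.
    return nn.Sequential(
        nn.EmbeddingBag(100, 16),  # mean of the token embeddings per document
        nn.Linear(16, 3),
        nn.LogSoftmax(dim=1),      # NLLLoss expects log-probabilities
    )


workdir = tempfile.mkdtemp()
PATH = os.path.join(workdir, "rcnn.pt")
FINETUNED_PATH = os.path.join(workdir, "rcnn_finetuned.pt")

# Stand-in for the original, already-trained model that was saved earlier
torch.manual_seed(0)
torch.save(make_model().state_dict(), PATH)

# Later: rebuild the same architecture and restore its weights
model = make_model()
model.load_state_dict(torch.load(PATH))

# Placeholder "new documents": 32 token-index sequences of length 20, with labels
new_x = torch.randint(0, 100, (32, 20))
new_y = torch.randint(0, 3, (32,))
loader = DataLoader(TensorDataset(new_x, new_y), batch_size=8, shuffle=True)

# A small learning rate is a common choice when continuing from trained weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.NLLLoss()

model.train()
for epoch in range(3):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()

# Save to a separate file so the original weights stay available
torch.save(model.state_dict(), FINETUNED_PATH)
```

Writing the fine-tuned weights to a new file keeps the original checkpoint intact. You can then compare both models on whatever evaluation data you still have.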

Note: You can load it directly on the desired device by passing a torch.device to torch.load's map_location argument.
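For example, this sketch loads a saved state_dict straight onto a chosen device with map_location. It uses a hypothetical nn.Linear stand-in and an in-memory buffer instead of a file (a file path works the same way). load_state_dict copies values into the existing parameters, so the receiving model still needs to be moved to the device.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the trained RCNN

# Save into an in-memory buffer; torch.save also accepts a file path
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Pick the target device, then map every saved tensor onto it while loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(buffer, map_location=device)

restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)  # copies values into the existing (CPU) parameters
restored.to(device)                   # move the module itself to the target device
```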

Upvotes: 1
