Latent

Reputation: 596

Saving the vocabulary object from PyTorch's torchtext library

I am building a text classification model using PyTorch's torchtext. The vocabulary object lives on the data.Field:

def create_tabularDataset_object(self, csv_path):
    self.TEXT = data.Field(tokenize=self.tokenizer, batch_first=True, include_lengths=True)
    self.LABEL = data.LabelField(dtype=torch.float, batch_first=True)

def get_vocab_with_glov(self, data):
    # initialize GloVe embeddings
    self.TEXT.build_vocab(data, min_freq=100, vectors="glove.6B.100d")

After training, when serving the model in production, how do I keep the TEXT object around? At prediction time I need it to map word tokens to their indices:

[TEXT.vocab.stoi[t] for t in tokenized_sentence]

Am I missing something, and is it not actually necessary to keep that object? Do I need any file other than the model weights?

Upvotes: 3

Views: 6301

Answers (2)

seel

Reputation: 135

Actually, the best (most stable) way to do this is to use PyTorch's built-in torch.save function.

To save the object to a file:

torch.save(vocab_obj, 'vocab_obj.pth')

To load the file again:

vocab_obj = torch.load('vocab_obj.pth')

Upvotes: 9

Latent

Reputation: 596

I've found that saving TEXT.vocab as a pickle file works:

import pickle

def save_vocab(vocab, path):
    # the context manager closes the file even if pickling fails
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

Where

vocab = TEXT.vocab 

and reading it back as usual with pickle.load.
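For completeness, here is a rough sketch of the full round trip, including the lookup at prediction time. The small dict is only a hypothetical stand-in for TEXT.vocab.stoi so the snippet runs without torchtext, and load_vocab and tokens_to_ids are helper names I made up for illustration. I'm assuming index 0 is the unknown token, which you should check against your own vocab.

```python
import os
import pickle
import tempfile


def save_vocab(vocab, path):
    """Serialize a vocabulary object to `path` with pickle."""
    with open(path, "wb") as f:
        pickle.dump(vocab, f)


def load_vocab(path):
    """Load a vocabulary object previously written by save_vocab."""
    with open(path, "rb") as f:
        return pickle.load(f)


def tokens_to_ids(stoi, tokens, unk_index=0):
    """Map tokens to integer ids, falling back to unk_index for unseen tokens."""
    return [stoi.get(t, unk_index) for t in tokens]


# Hypothetical stand-in for TEXT.vocab.stoi; a real one comes from build_vocab.
stoi = {"<unk>": 0, "<pad>": 1, "good": 2, "movie": 3}

# Training side: write the vocabulary next to the model weights.
path = os.path.join(tempfile.mkdtemp(), "vocab.pkl")
save_vocab(stoi, path)

# Serving side: restore it and index a tokenized sentence.
restored = load_vocab(path)
ids = tokens_to_ids(restored, ["good", "movie", "unseen"])
```

So in production you would ship two files, the model weights and the pickled vocab. Only unpickle files you created yourself, since pickle can execute arbitrary code on load.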

Upvotes: 2
