D Liebman

Reputation: 347

Pytorch save embeddings as part of encoder class or not

So I'm using PyTorch for the first time, and I'm trying to save weights to a file. I'm using an Encoder class that has a GRU and an embedding component. I want to make sure that when I save the Encoder's values, I will also get the embedding values. Initially my code uses state_dict() to copy values to a dictionary of my own, which I then pass to torch.save(). Should I be looking for some way to save this embedding component separately, or is it part of the larger encoder? The Encoder is a subclass of nn.Module. Here's a link:

http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#sphx-glr-intermediate-seq2seq-translation-tutorial-py

def make_state(self, converted=False):
    if not converted:
        z = [
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_1.state_dict(),
                'best_prec1': None,
                'optimizer': self.opt_1.state_dict(),
                'best_loss': self.best_loss
            },
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_2.state_dict(),
                'best_prec1': None,
                'optimizer': self.opt_2.state_dict(),
                'best_loss': self.best_loss
            }
        ]
    else:
        z = [
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_1.state_dict(),
                'best_prec1': None,
                'optimizer': None,  # self.opt_1.state_dict(),
                'best_loss': self.best_loss
            },
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_2.state_dict(),
                'best_prec1': None,
                'optimizer': None,  # self.opt_2.state_dict(),
                'best_loss': self.best_loss
            }
        ]
    return z

def save_checkpoint(self, state=None, is_best=True, num=0, converted=False):
    if state is None:
        state = self.make_state(converted=converted)
        if converted: print(converted, 'is converted.')
    basename = hparams['save_dir'] + hparams['base_filename']
    torch.save(state, basename + '.' + str(num)+ '.pth.tar')
    if is_best:
        os.system('cp '+ basename + '.' + str(num) + '.pth.tar' + ' '  +
                  basename + '.best.pth.tar')
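For reference, a checkpoint saved in this list-of-dicts layout can be restored by indexing back into the list and calling load_state_dict. A minimal round-trip sketch with a stand-in module (the names here are illustrative, not from the code above):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for model_1; any nn.Module works the same way.
model = nn.Linear(3, 2)
state = [{'epoch': 0, 'state_dict': model.state_dict(), 'best_loss': None}]

path = os.path.join(tempfile.gettempdir(), 'demo_checkpoint.pth.tar')
torch.save(state, path)

# Restore into a fresh module with the same architecture.
restored = nn.Linear(3, 2)
restored.load_state_dict(torch.load(path)[0]['state_dict'])
assert torch.equal(restored.weight, model.weight)
```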

Here is another link:

https://discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/2610/3

Upvotes: 1

Views: 3828

Answers (1)

layog

Reputation: 4801

No, you do not need to save the embedding values explicitly. Saving a model’s state_dict will save all the variables pertaining to that model, including the embedding weights.
You can see what a state dict contains by looping over it:

for var_name in model.state_dict():
    print(var_name)
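For example, with an encoder along the lines of the tutorial's (module names here are assumed for illustration), the embedding weights show up as ordinary entries in the state dict alongside the GRU parameters:

```python
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

for var_name in Encoder().state_dict():
    print(var_name)
# embedding.weight
# gru.weight_ih_l0
# gru.weight_hh_l0
# gru.bias_ih_l0
# gru.bias_hh_l0
```

Since `embedding.weight` is in the state dict, torch.save on that dict preserves the embedding with everything else.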

Upvotes: 1
