user15032198

Reputation: 1

Python PyTorch: how to use a pretrained model

I trained a model using this GitHub repository. It's a CRNN[10] model, and now I want to use it to make predictions. From what I've read, I need to execute this:

    import torch

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

To do that, I need the model class of my network. How do I find out which class I have? I tried torchvision.models.crnn() and torchvision.models.crnn10(), but neither of them works. Can anyone tell me how to load my model?
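torchvision.models does not provide a CRNN, so neither call can work: the class has to come from the code the model was trained with. A minimal sketch of why, using a hypothetical TinyNet as a stand-in for the real CRNN and an in-memory buffer in place of the file at PATH: a saved state_dict contains only parameter names and tensor shapes (handy for checking which architecture a checkpoint belongs to), while the class itself must be instantiated before load_state_dict can copy the weights in.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the class the model was trained with."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.encoder(x)))


def describe_checkpoint(state_dict: dict) -> list[str]:
    """Return one 'name: shape' line per tensor stored in a state_dict."""
    return [f"{name}: {tuple(t.shape)}" for name, t in state_dict.items()]


# Save the weights the way a training script usually does (BytesIO stands in for PATH).
trained = TinyNet()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# The checkpoint holds parameter names and tensor shapes, but no class definition.
checkpoint = torch.load(buffer, map_location="cpu")
for line in describe_checkpoint(checkpoint):
    print(line)  # e.g. "encoder.weight: (16, 8)"

# To use the weights, rebuild the architecture from its class and copy them in.
restored = TinyNet()
restored.load_state_dict(checkpoint)
restored.eval()  # inference behaviour for dropout / batch-norm layers
```

If the class you instantiate does not match the checkpoint, load_state_dict (strict by default) raises a RuntimeError listing missing keys, unexpected keys and size mismatches, which is a quick way to tell whether you picked the right class and settings.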

Upvotes: 0

Views: 125

Answers (1)

darth baba

Reputation: 1398

The Model class is defined in the GitHub repository you trained with. To load your trained weights, you can do the following:

  1. Clone the repository.

  2. Create your test script inside the cloned repository, so that "from model import Model" can find the model code.

  3. Add this code to load the model:

    import torch
    from model import Model  # model code from the cloned repository

    # opt must provide the attributes used at training time:
    # imgH, imgW, num_fiducial, input_channel, output_channel, hidden_size,
    # num_class, batch_max_length, Transformation, FeatureExtraction,
    # SequenceModeling, Prediction
    model = Model(opt)
    model.load_state_dict(torch.load(PATH, map_location="cpu"))  # "cpu" also loads GPU-trained weights
    model.eval()
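Model(opt) only reads attributes from opt, so any object with the right fields will do; the values must match the settings used for training, otherwise load_state_dict reports size mismatches. A sketch with two hypothetical helpers (names are illustrative, not part of the repository): build_opt uses types.SimpleNamespace and refuses to build an opt that lacks one of the fields listed in the comment above, and strip_module_prefix covers the case where the weights were saved from a model wrapped in torch.nn.DataParallel, whose state_dict keys all start with "module.". Whether your checkpoint needs the second step is an assumption to check by printing a few of its keys.

```python
from types import SimpleNamespace

# Fields listed in the answer's comment; the values must come from your training run.
REQUIRED_FIELDS = (
    "imgH", "imgW", "num_fiducial", "input_channel", "output_channel",
    "hidden_size", "num_class", "batch_max_length", "Transformation",
    "FeatureExtraction", "SequenceModeling", "Prediction",
)


def build_opt(**settings) -> SimpleNamespace:
    """Hypothetical helper: build an opt object, failing fast on missing fields."""
    missing = [name for name in REQUIRED_FIELDS if name not in settings]
    if missing:
        raise ValueError(f"opt is missing: {', '.join(missing)}")
    return SimpleNamespace(**settings)


def strip_module_prefix(state_dict: dict) -> dict:
    """Drop the 'module.' prefix that nn.DataParallel adds to every key."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

With these, the loading line would become model.load_state_dict(strip_module_prefix(torch.load(PATH, map_location="cpu"))). Alternatively, wrapping the model with torch.nn.DataParallel(model) before loading keeps the prefixed keys matching without rewriting them.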

Upvotes: 0
