Reputation: 1
I trained a model using this GitHub repository. It's a CRNN[10] model, and I now want to use it to make predictions. From what I've read, I need to execute this:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
To do that I need the model class of my network. How do I know which class I have? I tried torchvision.models.crnn() and torchvision.models.crnn10(), but neither of these works. Can anyone tell me how I can load my model?
Upvotes: 0
Views: 125
Reputation: 1398
The Model class is defined in the GitHub repository you trained with, not in torchvision. To use your trained model weights, do the following:
Clone the repository.
Create your test file inside the repository, so that the import below can find model.py.
Add this code to load the model:
import torch
from model import Model
# Create an opt object with all of these attributes, set to the values used in training:
# opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size,
# opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
# opt.SequenceModeling, opt.Prediction
model = Model(opt)
model.load_state_dict(torch.load(PATH))
model.eval()
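One way to build the opt object is sketched below. It uses types.SimpleNamespace from the standard library so the attributes can be read as opt.imgH and so on. The helper make_opt and the list REQUIRED_FIELDS are hypothetical names introduced for illustration, and every value shown is a placeholder assumption: replace each one with the setting you actually trained with, otherwise the layer shapes will not match the saved weights.

```python
from types import SimpleNamespace

# Attribute names taken from the comment in the answer above.
REQUIRED_FIELDS = (
    "imgH", "imgW", "num_fiducial", "input_channel", "output_channel",
    "hidden_size", "num_class", "batch_max_length", "Transformation",
    "FeatureExtraction", "SequenceModeling", "Prediction",
)


def make_opt(**settings):
    """Return an attribute-style options object, failing early if a field is missing."""
    missing = [name for name in REQUIRED_FIELDS if name not in settings]
    if missing:
        raise ValueError("missing opt fields: " + ", ".join(missing))
    return SimpleNamespace(**settings)


# Placeholder values only (assumptions); use the ones from your training run.
opt = make_opt(
    imgH=32,
    imgW=100,
    num_fiducial=20,
    input_channel=1,
    output_channel=512,
    hidden_size=256,
    num_class=37,
    batch_max_length=25,
    Transformation="None",
    FeatureExtraction="VGG",
    SequenceModeling="BiLSTM",
    Prediction="CTC",
)

print(opt.imgH, opt.imgW, opt.Prediction)
```

The resulting opt can then be passed to Model(opt) as in the code above. If the weights were saved on a GPU and you are loading them on a CPU-only machine, torch.load(PATH, map_location="cpu") may be needed.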
Upvotes: 0