Mohamed Ibrahim

Reputation: 291

Issue loading a model using PyTorch in Google Colaboratory

I am trying to load a model in Google Colaboratory so I can evaluate it and generate all the statistics.

What I tried:

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import torch.nn as nn
import os

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    
    return model

PATH = "/content/gdrive/MyDrive/best.pt"
state_dict = load_checkpoint(PATH)

The Error

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-24-0515f2edfa1a> in <module>()
     18 
     19 PATH = "/content/gdrive/MyDrive/best.pt"
---> 20 state_dict = load_checkpoint(PATH)

2 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
    849     unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    850     unpickler.persistent_load = persistent_load
--> 851     result = unpickler.load()
    852 
    853     torch._utils._validate_loaded_sparse_tensors()

ModuleNotFoundError: No module named 'models'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

I tried installing some libraries, but I get the same error. Is there any way to load models inside Google Colaboratory?

Upvotes: 1

Views: 185

Answers (1)

Natthaphon Hongcharoen

Reputation: 2430

The problem is that when you saved the weights, you actually used torch.save(model, ...) instead of torch.save(model.state_dict(), ...).
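A minimal sketch of the state_dict approach, assuming a hypothetical TinyNet class in place of the real network and an in-memory buffer in place of best.pt. Only parameter names and tensors are saved, so loading does not need to import the module the class lived in during training; you do still need the class definition in the notebook to rebuild the model.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the network that was trained."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


trained = TinyNet()

# Save only the weights (an ordered mapping of parameter names to tensors).
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# At load time: rebuild the architecture in code, then copy the weights in.
buffer.seek(0)
state_dict = torch.load(buffer, map_location="cpu")
restored = TinyNet()
restored.load_state_dict(state_dict)
restored.eval()  # switch to inference mode before evaluating
```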

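One way to solve this is to import the models module the same way you did when training. This matters because when you save the whole model, pickle stores a reference to the model's class by its module and class name (here, something inside models) along with the weights, so that module has to be importable when you load the file.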
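The sketch below illustrates that mechanism with a throwaway module. demo_models is a hypothetical name standing in for models, TinyNet is a placeholder network, and it assumes PyTorch 1.13 or newer (for the weights_only argument, which must be False to unpickle a whole model on recent releases). Removing the module from sys.modules reproduces the same kind of ModuleNotFoundError, and registering it again under the same name makes the load succeed.

```python
import io
import sys
import types

import torch
import torch.nn as nn

# Hypothetical module standing in for the 'models' package used in training.
demo_models = types.ModuleType("demo_models")
sys.modules["demo_models"] = demo_models


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Pretend TinyNet was defined inside demo_models, as it would be during training.
TinyNet.__module__ = "demo_models"
demo_models.TinyNet = TinyNet

buffer = io.BytesIO()
torch.save(TinyNet(), buffer)  # whole model: pickle records "demo_models.TinyNet"

# Simulate a fresh Colab runtime where the module does not exist.
del sys.modules["demo_models"]
error_message = None
buffer.seek(0)
try:
    torch.load(buffer, weights_only=False)
except ModuleNotFoundError as exc:
    error_message = str(exc)
print(error_message)  # No module named 'demo_models'

# Making the module importable again under the same name fixes the load.
sys.modules["demo_models"] = demo_models
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
```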

If models is a file, you may need to upload it to Colab so it can be imported. If it's an object defined in code, paste its definition into a cell and run it, and the load should work.
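In Colab, one way to do that is to upload models.py (or the models folder) to Drive and put its parent directory on sys.path before calling torch.load. This is a sketch under assumptions: load_full_model and demo_net are hypothetical names, a temporary directory simulates a Drive folder such as /content/gdrive/MyDrive/project, and PyTorch 1.13 or newer is assumed for weights_only. After the model loads, it re-saves just the state_dict so later loads no longer depend on the module path.

```python
import importlib
import os
import sys
import tempfile

import torch


def load_full_model(checkpoint_path, project_dir):
    """Load a model saved with torch.save(model, ...).

    project_dir is the folder containing the training code (e.g. models.py
    uploaded to Drive); adding it to sys.path lets pickle import that module.
    """
    if project_dir not in sys.path:
        sys.path.insert(0, project_dir)
    return torch.load(checkpoint_path, map_location="cpu", weights_only=False)


# Demo: a temporary folder plays the role of a Drive folder (hypothetical path).
project_dir = tempfile.mkdtemp()
with open(os.path.join(project_dir, "demo_net.py"), "w") as f:
    f.write(
        "import torch.nn as nn\n"
        "class Net(nn.Module):\n"
        "    def __init__(self):\n"
        "        super().__init__()\n"
        "        self.fc = nn.Linear(3, 1)\n"
    )

# "Training" side: import the module and save the whole model.
sys.path.insert(0, project_dir)
demo_net = importlib.import_module("demo_net")
checkpoint_path = os.path.join(project_dir, "best.pt")
torch.save(demo_net.Net(), checkpoint_path)

# "Colab" side: forget the module and the path, as in a fresh runtime.
sys.path.remove(project_dir)
del sys.modules["demo_net"]

model = load_full_model(checkpoint_path, project_dir)

# Once it loads, re-save only the weights so future loads skip the module lookup.
weights_path = os.path.join(project_dir, "best_state_dict.pt")
torch.save(model.state_dict(), weights_path)
```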

Upvotes: 1
