Data Fizz

Reputation: 11

Using pre-trained weights in PyTorch

I am implementing a computer vision research paper in PyTorch and have built the model architecture by following the paper. The author has uploaded the trained weights to GitHub in ".pth.tar" format. I want to load those same weights into my model so that I can skip the training and optimization steps and get outputs from the network directly.

The paper is "Learning to See in the Dark".

The model architecture is as follows:

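import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(32, 12, 1)
        # ... remaining layers from the paper ...

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # ... remaining layers ...
        return x

net = Net()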
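To see which parameter names and shapes the author's checkpoint has to provide, you can list the model's state_dict. The sketch below uses a hypothetical cut-down Net that keeps only the conv1 layer from the question (the elided layers are left out), so the full model will list more entries; describe_state_dict is an illustrative helper, not a PyTorch function.

```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Hypothetical cut-down version of the question's model: only conv1 is kept.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(32, 12, 1)

    def forward(self, x):
        return F.relu(self.conv1(x))


def describe_state_dict(model):
    """Return (name, shape) pairs for every tensor a checkpoint must supply."""
    return [(name, tuple(t.shape)) for name, t in model.state_dict().items()]


if __name__ == "__main__":
    for name, shape in describe_state_dict(Net()):
        print(name, shape)
    # prints:
    # conv1.weight (12, 32, 1)
    # conv1.bias (12,)
```

Comparing this list with the keys stored in the downloaded file is a quick way to spot naming mismatches before calling load_state_dict.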

This is to be followed by importing the trained weights from Google Drive/cloud storage and defining a function that loads them into the net.

PS: The model architecture is exactly the same in both.
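Because both architectures are identical, every parameter name in the author's state_dict should match one in your model, which is exactly what load_state_dict checks. A minimal round-trip sketch, assuming a hypothetical cut-down Net with only conv1 and a local temporary file standing in for the author's download (copy_weights_via_file is an illustrative helper):

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical stand-in for the question's model (only conv1 is kept).
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(32, 12, 1)

    def forward(self, x):
        return torch.relu(self.conv1(x))


def copy_weights_via_file(source, target, path):
    """Save source's state_dict to path, then load it into target."""
    torch.save(source.state_dict(), path)
    state_dict = torch.load(path, map_location="cpu")
    # strict=True (the default) raises if any parameter name is missing or unexpected.
    target.load_state_dict(state_dict)
    target.eval()  # inference mode; matters for dropout/batch-norm layers if present
    return target


if __name__ == "__main__":
    trained, fresh = Net(), Net()
    with tempfile.TemporaryDirectory() as tmp:
        copy_weights_via_file(trained, fresh, os.path.join(tmp, "weights.pth"))
    x = torch.randn(1, 32, 5)
    print(torch.equal(trained(x), fresh(x)))  # True
```

After loading, the fresh model produces the same outputs as the one the weights came from, with no training step involved.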

Upvotes: 1

Views: 1345

Answers (1)

risper

Reputation: 58

If you are using Google Colab:

# Mount Google Drive in Google Colab
from google.colab import drive
drive.mount('/content/gdrive')

Define the path of the weights

weights_path="/content/gdrive/My Drive/weights.pth"

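Extract the tar file (only needed if it really is a tar archive; a checkpoint saved with torch.save under a .pth.tar name is normally not one and can be passed to torch.load directly)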

!tar -xvf weights.pth.tar
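Whether extraction is needed can be checked from Python with the standard tarfile module. This is a sketch; extract_if_tar is a hypothetical helper name, and the file names in the usage note are placeholders. Only extract archives from a source you trust, since tar members can write outside the destination directory.

```python
import tarfile


def extract_if_tar(path, dest="."):
    """Extract path into dest if it is a tar archive and return the member names.

    Returns None when the file is not a tar archive; in that case the file
    can be handed to torch.load unchanged.
    """
    if not tarfile.is_tarfile(path):
        return None
    with tarfile.open(path) as archive:
        names = archive.getnames()
        archive.extractall(dest)  # trusted archives only
    return names
```

For example, extract_if_tar("weights.pth.tar") either returns the list of extracted files or None, telling you whether to load the extracted file or the original one.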

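Load the weights into the model net. torch.load returns whatever object was saved; if the file holds a state_dict (the usual case), load it into the existing model instead of overwriting net:

net.load_state_dict(torch.load(weights_path, map_location="cpu"))
net.eval()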
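Published checkpoints often wrap the weights in a larger dict (with entries such as an epoch counter or optimizer state) and may carry a "module." prefix if the model was trained with nn.DataParallel. The sketch below handles those two layouts; the "state_dict" key and the prefix are common conventions assumed here, not confirmed for this paper's repository, and extract_state_dict / load_pretrained are hypothetical helper names. If the author saved the whole model object instead, recent PyTorch versions refuse to load it under the default weights_only=True, so inspect the file first.

```python
import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Return a plain state_dict from a few common checkpoint layouts.

    Assumed layouts: a bare state_dict, or a dict storing it under "state_dict".
    A "module." prefix left by nn.DataParallel is stripped from the keys.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in checkpoint.items()
    }


def load_pretrained(model: nn.Module, path: str) -> nn.Module:
    """Load the weights at path into model and switch it to inference mode."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(extract_state_dict(checkpoint))
    model.eval()
    return model
```

With the question's setup this would be called as net = load_pretrained(Net(), weights_path); load_state_dict still raises if any key does not match, which guards against silently using a mismatched architecture.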

Upvotes: 1
