Ha An Tran

Reputation: 357

How to wrap a torch.jit model inside a torch Module?

I'm trying to call a TorchScript model inside a torch.nn.Module, but I get a pickle-related error.

Here's the code to reproduce:

import torch
import torch.nn as nn

# A simple base model to create a ScriptModel
class ExampleModel(nn.Module):
    def __init__(self, factor: int):
        super(ExampleModel, self).__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

# Define a wrapper model around the loaded TorchScript model
class WrapperModel(nn.Module):
    def __init__(self, path):
        super(WrapperModel, self).__init__()
        self.model = torch.jit.load(path)


    def forward(self, name: str, x):
        return self.model(x)


scripted_model = torch.jit.script(ExampleModel(2))
scripted_model.save("model.jit")


# Initialize the WrapperModel
wrapper = WrapperModel("model.jit")

And when I try to pickle wrapper with:

import pickle
pickle.dumps(wrapper)

I get this error:

RuntimeError: Tried to serialize object __torch__.___torch_mangle_3.ExampleModel which does not have a __getstate__ method defined!

Is there a way to call a TorchScript model so that pickling the wrapper doesn't raise this error?

Upvotes: 0

Views: 23

Answers (1)

Aryan Raj

Reputation: 248

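TorchScript models aren't meant to be pickled directly. As the error says, the scripted class has no __getstate__ defined, and the supported way to serialize one is torch.jit.save. Renaming the attribute (for example to self._model) doesn't help, because nn.Module registers every submodule assigned to it, whatever the attribute is called. One fix is to give the wrapper its own __getstate__ and __setstate__, so that pickle receives the TorchScript bytes instead of the module itself:

import io

class WrapperModel(nn.Module):
    def __init__(self, path):
        super(WrapperModel, self).__init__()
        self.model = torch.jit.load(path)

    def forward(self, x):
        return self.model(x)

    def __getstate__(self):
        buffer = io.BytesIO()
        torch.jit.save(self.model, buffer)
        return {"model_bytes": buffer.getvalue()}

    def __setstate__(self, state):
        # pickle bypasses __init__, so initialize nn.Module before assigning
        nn.Module.__init__(self)
        self.model = torch.jit.load(io.BytesIO(state["model_bytes"]))

With these two methods, pickle.dumps(wrapper) stores the model as the bytes written by torch.jit.save, and pickle.loads rebuilds it with torch.jit.load. This only carries the wrapped model across; any other state on the wrapper would have to be added to the dict returned by __getstate__.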
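To try a wrapper like this end to end without touching the disk, here is a minimal sketch that hands pickle the bytes produced by torch.jit.save and rebuilds the model with torch.jit.load. The class names Doubler and PicklableWrapper and the "model_bytes" key are hypothetical, chosen only for illustration, and the wrapper takes an already scripted module rather than a path. Treat it as one possible approach under those assumptions, not the only one.

```python
import io
import pickle

import torch
import torch.nn as nn


class Doubler(nn.Module):
    """Hypothetical stand-in for the question's ExampleModel with factor 2."""

    def forward(self, x):
        return x * 2


class PicklableWrapper(nn.Module):
    """Hypothetical wrapper that pickles its ScriptModule as raw bytes."""

    def __init__(self, scripted):
        super().__init__()
        self.model = scripted

    def forward(self, x):
        return self.model(x)

    def __getstate__(self):
        # Serialize the ScriptModule with torch.jit.save instead of pickle.
        buffer = io.BytesIO()
        torch.jit.save(self.model, buffer)
        return {"model_bytes": buffer.getvalue()}

    def __setstate__(self, state):
        # pickle skips __init__, so set up nn.Module's internals first.
        nn.Module.__init__(self)
        self.model = torch.jit.load(io.BytesIO(state["model_bytes"]))


# Script the model in memory, wrap it, and round-trip the wrapper via pickle.
wrapper = PicklableWrapper(torch.jit.script(Doubler()))
restored = pickle.loads(pickle.dumps(wrapper))

# Both wrappers should produce the same output for the same input.
x = torch.tensor([1.0, 2.0, 3.0])
same_output = torch.equal(wrapper(x), restored(x))
```

The sketch only carries the wrapped model through pickle. Anything else stored on the wrapper (extra attributes, hooks, the training flag) is reset by nn.Module.__init__ in __setstate__ and would need to be added to the state dict explicitly.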

Upvotes: 0
