qhu

Reputation: 1

PyTorch JIT script error when Sequential container takes a Tuple input


This is a simple net that reproduces my error. I'm passing a tuple to the forward method and have annotated its type. I think the error is caused by the JIT inferring the input type of Sequential's forward method to be a Tensor rather than a Tuple. How can I fix this error?

from typing import Tuple

import torch
from torch import nn

class MyBatchNorm(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        self.d_ids = d_ids
        self.net = nn.ModuleDict({f"{d}": nn.BatchNorm1d(output_size) for d in d_ids})
    
    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = torch.tensor([])
        for d_name, d_norm in self.net.items():
            if f"{d}" == d_name:
                output_tensor = d_norm(input_tensor)
        if len(output_tensor) == 0:
            raise ValueError(f"invalid d {d}, must be {self.d_ids}")
        return output_tensor, d

class MyNet(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        dense_layers = [
            MyBatchNorm(output_size, d_ids),
            MyBatchNorm(output_size, d_ids)
        ]
        self.net = torch.nn.Sequential(*dense_layers)
        
    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        d = d_tensor.squeeze()[0].item()
        output_tensor, _ = self.net((input_tensor, d))
        return torch.squeeze(output_tensor)

Error:

RuntimeError: 

forward(__torch__.___torch_mangle_16.MyBatchNorm self, (Tensor, int) input_tuple) -> ((Tensor, int)):
Expected a value of type 'Tuple[Tensor, int]' for argument 'input_tuple' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'input_tuple' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
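The hypothesis above can be checked in isolation. The following is a minimal sketch; `PassTuple`, `Wrapper` and `scripts_ok` are hypothetical names made up for this illustration, standing in for `MyBatchNorm` and `MyNet`. A Sequential of plain Tensor-to-Tensor layers should script fine, while one whose first layer expects a `(Tensor, int)` tuple should fail when TorchScript compiles Sequential's unannotated `forward`.

```python
from typing import Tuple

import torch
from torch import nn


class PassTuple(nn.Module):
    """Hypothetical stand-in for MyBatchNorm: takes and returns (Tensor, int)."""

    def forward(self, x: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        return x


class Wrapper(nn.Module):
    """Hypothetical stand-in for MyNet: feeds a tuple into nn.Sequential."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(PassTuple(), PassTuple())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.seq((x, 1))
        return out


def scripts_ok(module: nn.Module) -> bool:
    """Return True if torch.jit.script compiles the module, False on a compile error."""
    try:
        torch.jit.script(module)
        return True
    except RuntimeError:  # TorchScript compile errors surface as RuntimeError
        return False


# Tensor in, Tensor out: Sequential's inferred `input: Tensor` is correct here.
print(scripts_ok(nn.Sequential(nn.Linear(2, 2), nn.ReLU())))  # expected: True
# Tuple in: compiling Sequential.forward fails, as in the traceback above.
print(scripts_ok(Wrapper()))  # expected: False
```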

Upvotes: 0

Views: 655

Answers (1)

mintermine

Reputation: 85

This is not a real fix, but it is a workaround for this particular issue: instead of nn.Sequential, use nn.ModuleList.

Let's start with a module that takes several tensors and returns them as a tuple:

import torch
from torch import nn

class DoubleMe(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        return x * 2, y * 2, z * 2

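The following attempt to chain two DoubleMe modules with nn.Sequential cannot be converted with TorchScript. (With this three-argument DoubleMe it fails in PyTorch's default eager mode as well, because nn.Sequential passes a single value from each stage to the next.)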

class QuadrupleMe(nn.Module):
    def __init__(self):
        super().__init__()

        self.sequence = nn.Sequential(DoubleMe(), DoubleMe())

    def forward(self, x: list[torch.Tensor]):
        return self.sequence(x)

The class above hits the same root cause as the question (nn.Sequential's forward takes a single, unannotated input), which can be reproduced with the following:

quad_me = QuadrupleMe()
script = torch.jit.script(quad_me)

which fails with:

forward(__torch__.DoubleMe self, Tensor x, Tensor y, Tensor z) -> ((Tensor, Tensor, Tensor)):
Argument y not provided.
:
  File "/home/username/anaconda3/envs/pixels/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input

Now, let's convert QuadrupleMe to a scriptable module with nn.ModuleList:

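class QuadrupleMe(nn.Module):
    def __init__(self):
        super().__init__()

        # nn.ModuleList only registers the submodules; calling them is left to forward.
        # (The attribute is not named `modules`, which would clash with nn.Module.modules().)
        self.doublers = nn.ModuleList([DoubleMe(), DoubleMe()])

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        # TorchScript unrolls this loop, so each call is checked against DoubleMe's
        # own signature rather than nn.Sequential's single unannotated input.
        for module in self.doublers:
            x, y, z = module(x, y, z)
        return x, y, z

Note that nn.Sequential is swapped for nn.ModuleList in __init__; ModuleList's constructor takes an iterable of modules, hence the list. The forward method then does the chaining itself, unpacking each stage's outputs into the next stage's arguments, so nothing has to pass through nn.Sequential's single, unannotated (and therefore Tensor-typed) input.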

The loop mimics nn.Sequential's own forward method, which appears in both tracebacks above (torch/nn/modules/container.py).
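Applied to the question's MyNet, the same workaround might look like the sketch below, assuming every layer passes a `(Tensor, int)` tuple along: the layers live in an `nn.ModuleList` and `forward` loops over them. Two further changes are my own assumptions, not from the question: `d` is wrapped in `int(...)`, because `Tensor.item()` is typed as a generic number in TorchScript rather than `int`, and the lookup uses a `found` flag instead of checking `len()` of a `torch.tensor([])` placeholder. Treat it as a starting point, not a verified drop-in replacement.

```python
from typing import List, Tuple

import torch
from torch import nn


class MyBatchNorm(nn.Module):
    def __init__(self, output_size: int, d_ids: List[int]):
        super().__init__()
        # One BatchNorm1d per id, keyed by the id as a string.
        self.net = nn.ModuleDict({str(d): nn.BatchNorm1d(output_size) for d in d_ids})

    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = input_tensor
        found = False
        # TorchScript unrolls iteration over a ModuleDict at compile time.
        for d_name, d_norm in self.net.items():
            if str(d) == d_name:
                output_tensor = d_norm(input_tensor)
                found = True
        if not found:
            raise ValueError("invalid d " + str(d))
        return output_tensor, d


class MyNet(nn.Module):
    def __init__(self, output_size: int, d_ids: List[int]):
        super().__init__()
        # ModuleList instead of nn.Sequential: forward calls the layers itself.
        self.layers = nn.ModuleList(
            [MyBatchNorm(output_size, d_ids), MyBatchNorm(output_size, d_ids)]
        )

    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        # int(...) gives TorchScript an int, matching Tuple[Tensor, int].
        d = int(d_tensor.squeeze()[0].item())
        state: Tuple[torch.Tensor, int] = (input_tensor, d)
        # Each layer receives and returns the (Tensor, int) tuple.
        for layer in self.layers:
            state = layer(state)
        return torch.squeeze(state[0])


net = MyNet(output_size=3, d_ids=[0, 1])
scripted = torch.jit.script(net)
x = torch.randn(4, 3)
d_tensor = torch.zeros(4, dtype=torch.long)
print(scripted(x, d_tensor).shape)  # expected: torch.Size([4, 3])
```

The tuple is threaded through the loop exactly as nn.Sequential would thread a single tensor; the difference is that the loop variable's type is declared by us instead of being inferred as Tensor.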

Upvotes: 0
