PyTorch JIT script error when Sequential container takes a Tuple input
This is a simple net to reproduce my error. I’m passing a Tuple to the forward method and have annotated its type. I think the error is caused by the JIT inferring the input type of Sequential’s forward method to be a Tensor rather than a Tuple. How can I fix this error?
from typing import Tuple

import torch
import torch.nn as nn

class MyBatchNorm(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        self.d_ids = d_ids
        # one BatchNorm1d per id; ModuleDict keys must be strings
        self.net = nn.ModuleDict({f"{d}": nn.BatchNorm1d(output_size) for d in d_ids})

    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = torch.tensor([])
        # pick the BatchNorm that matches id d
        for d_name, d_norm in self.net.items():
            if f"{d}" == d_name:
                output_tensor = d_norm(input_tensor)
        if len(output_tensor) == 0:
            raise ValueError(f"invalid d {d}, must be {self.d_ids}")
        return output_tensor, d
class MyNet(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        dense_layers = [
            MyBatchNorm(output_size, d_ids),
            MyBatchNorm(output_size, d_ids)
        ]
        self.net = torch.nn.Sequential(*dense_layers)

    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        d = d_tensor.squeeze()[0].item()
        output_tensor, _ = self.net((input_tensor, d))
        return torch.squeeze(output_tensor)
Error:
RuntimeError:
forward(__torch__.___torch_mangle_16.MyBatchNorm self, (Tensor, int) input_tuple) -> ((Tensor, int)):
Expected a value of type 'Tuple[Tensor, int]' for argument 'input_tuple' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'input_tuple' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
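The traceback supports the diagnosis in the question: nn.Sequential's own forward(self, input) leaves input unannotated, and the message says TorchScript therefore inferred it as Tensor. One alternative (not taken from the answer below) is to subclass nn.Sequential and re-declare forward with the tuple type. The following is a minimal sketch under that assumption; AddOffset, TupleSequential and Net are hypothetical names, with AddOffset standing in for MyBatchNorm.

```python
from typing import Tuple

import torch
import torch.nn as nn


class AddOffset(nn.Module):
    """Hypothetical stand-in for MyBatchNorm: takes and returns (Tensor, int)."""

    def forward(self, state: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        x, d = state
        return x + d, d


class TupleSequential(nn.Sequential):
    """nn.Sequential with an annotated forward, so TorchScript does not infer Tensor."""

    def forward(self, state: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        # same loop as nn.Sequential.forward, but with an explicit input type
        for module in self:
            state = module(state)
        return state


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = TupleSequential(AddOffset(), AddOffset())

    def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
        out, _ = self.net((x, d))
        return out


scripted = torch.jit.script(Net())
print(scripted(torch.zeros(2), 3))  # tensor([6., 6.])
```

The only change from the stock container is the annotation on forward; the layers are still registered and iterated exactly as nn.Sequential does it.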
This is not a real solution, but it is a workaround for this particular issue: instead of nn.Sequential, use nn.ModuleList.
Let's start with a module whose forward takes three tensors and returns a tuple:
import torch
import torch.nn as nn

class DoubleMe(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        return x * 2, y * 2, z * 2
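The following nn.Sequential version fails when the model is converted with torchscript. (It cannot run in PyTorch's default eager mode either, because nn.Sequential passes a single value to each submodule while DoubleMe.forward expects three tensors.)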
class QuadrupleMe(nn.Module):
    def __init__(self):
        super().__init__()
        self.sequence = nn.Sequential(DoubleMe(), DoubleMe())

    def forward(self, x: list[torch.Tensor]):
        return self.sequence(x)
The class above fails at the same place as in the question (nn.Sequential's forward in container.py). It can be reproduced with:
quad_me = QuadrupleMe()
script = torch.jit.script(quad_me)
which fails with:
forward(__torch__.DoubleMe self, Tensor x, Tensor y, Tensor z) -> ((Tensor, Tensor, Tensor)):
Argument y not provided.
:
  File "/home/username/anaconda3/envs/pixels/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
Now, let's convert QuadrupleMe to a scriptable module with nn.ModuleList:
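from typing import Tuple

class QuadrupleMe(nn.Module):
    def __init__(self):
        super().__init__()
        # not named self.modules: that name clashes with nn.Module.modules()
        self.layers = nn.ModuleList([DoubleMe(), DoubleMe()])

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # call each submodule with its full signature, as nn.Sequential cannot
        for layer in self.layers:
            x, y, z = layer(x, y, z)
        return x, y, z

Note that nn.Sequential is swapped for nn.ModuleList in __init__, stored as self.layers because self.modules would clash with the built-in nn.Module.modules() method. Because forward is now written by hand with annotated arguments, TorchScript compiles each submodule call with the real argument types instead of inferring a single Tensor input. The loop mimics nn.Sequential's own forward method in torch/nn/modules/container.py.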
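Applied to the question's MyNet, the same workaround might look like the sketch below. It is an assumption-laden adaptation, not the asker's or answerer's code: MyBatchNorm is condensed (a found flag replaces the empty-tensor check), and int(...) replaces .item() so that d stays an int and the tuple matches Tuple[Tensor, int] under TorchScript.

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class MyBatchNorm(nn.Module):
    """Condensed version of the question's module: one BatchNorm1d per id."""

    def __init__(self, output_size: int, d_ids: List[int]):
        super().__init__()
        self.net = nn.ModuleDict({str(d): nn.BatchNorm1d(output_size) for d in d_ids})

    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = input_tensor
        found = False
        # TorchScript cannot index a ModuleDict with a runtime string, so loop over it
        for d_name, d_norm in self.net.items():
            if d_name == str(d):
                output_tensor = d_norm(input_tensor)
                found = True
        if not found:
            raise ValueError("invalid d")
        return output_tensor, d


class MyNet(nn.Module):
    def __init__(self, output_size: int, d_ids: List[int]):
        super().__init__()
        # ModuleList instead of Sequential: the loop in forward keeps the tuple type
        self.layers = nn.ModuleList(
            [MyBatchNorm(output_size, d_ids), MyBatchNorm(output_size, d_ids)]
        )

    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        d = int(d_tensor.squeeze()[0])  # an int, not the Scalar that .item() yields
        state = (input_tensor, d)
        for layer in self.layers:
            state = layer(state)
        return torch.squeeze(state[0])


net = MyNet(4, [0, 1])
scripted = torch.jit.script(net)
x = torch.randn(8, 4)
d_tensor = torch.zeros(8, 1, dtype=torch.long)
print(scripted(x, d_tensor).shape)  # torch.Size([8, 4])
```

Both the eager and the scripted model are in training mode here, so BatchNorm normalizes with batch statistics and the two should produce the same output.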