I am new to PyTorch, so please excuse my silly question.
I define an nn.Sequential in the __init__ of my Encoder module like this:
self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)
The forward() of EncoderBlock has this signature:
def forward(self, x, mask):
In the forward() of my Encoder, I try to do:
z0 = self.blocks(z0, mask)
I expect the nn.Sequential to pass both arguments to each individual block.
However, I get
TypeError: forward() takes 2 positional arguments but 3 were given
When I try:
z0 = self.blocks(z0)
I get (understandably):
TypeError: forward() takes 2 positional arguments but only 1 was given
When I do not use nn.Sequential and just execute one EncoderBlock after another, it works:
for i in range(self.n_blocks):
z0 = self.list_of_blocks[i](z0, mask)
Question: What am I doing wrong and how do I use nn.Sequential correctly in this case?
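For comparison, the working loop can be kept while still registering the blocks as submodules by storing them in an nn.ModuleList. A plain Python list on its own does not register its modules, so their parameters would be missing from .parameters() and would not be moved by .to(device). A minimal sketch, where ToyBlock is a hypothetical stand-in for EncoderBlock:

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for EncoderBlock with forward(x, mask)."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        # Residual update applied only where mask == 1
        return x + self.linear(x) * mask


class Encoder(nn.Module):
    def __init__(self, n_features, n_blocks):
        super().__init__()
        # ModuleList registers every block, unlike a plain Python list
        self.blocks = nn.ModuleList([ToyBlock(n_features) for _ in range(n_blocks)])

    def forward(self, z0, mask):
        # Explicit loop: each block receives both z0 and the shared mask
        for block in self.blocks:
            z0 = block(z0, mask)
        return z0


encoder = Encoder(n_features=4, n_blocks=3)
x = torch.randn(2, 5, 4)
mask = torch.ones(2, 5, 1)
out = encoder(x, mask)
print(out.shape)  # torch.Size([2, 5, 4])
```

The explicit loop makes the extra argument visible at the call site, which avoids needing a custom container at all.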
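nn.Sequential in general does not work with multiple inputs and outputs: its forward() takes a single input and passes each module's output on as the next module's input.
This is an often-discussed topic; see the PyTorch forum and GitHub issues #1908 and #9979.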
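Both errors from the question can be reproduced in a few lines. This is a minimal sketch assuming a hypothetical MaskedBlock with the same forward(x, mask) signature as EncoderBlock; the exact wording of the messages depends on the Python version:

```python
import torch
import torch.nn as nn


class MaskedBlock(nn.Module):
    """Hypothetical block whose forward needs both x and mask."""

    def forward(self, x, mask):
        return x * mask


seq = nn.Sequential(MaskedBlock(), MaskedBlock())
x = torch.ones(3)
mask = torch.tensor([1.0, 0.0, 1.0])

errors = []
for args in [(x, mask), (x,)]:
    try:
        seq(*args)
    except TypeError as exc:
        # (x, mask): Sequential.forward(input) receives one argument too many
        # (x,):      the first block is called without its mask argument
        errors.append(str(exc))

for msg in errors:
    print(msg)
```

The first failure happens inside nn.Sequential itself, before any block runs; the second happens when Sequential calls the first block with only the single value it was given.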
You can define your own version of Sequential. Assuming the mask is the same for all your encoder blocks (e.g., as in Transformer networks), you can do:
import torch.nn as nn

class MaskedSequential(nn.Sequential):
    def forward(self, x, mask):
        # Thread x through the blocks, passing the same mask to each one
        for module in self:
            x = module(x, mask)
        return x
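A quick sanity check of the mask-sharing container, sketched with a hypothetical ToyBlock in place of EncoderBlock: it should give exactly the same result as calling the blocks one after another, as in the question's working loop.

```python
import torch
import torch.nn as nn


class MaskedSequential(nn.Sequential):
    """Sequential container that passes the same mask to every block."""

    def forward(self, x, mask):
        for module in self:
            x = module(x, mask)
        return x


class ToyBlock(nn.Module):
    """Hypothetical stand-in for EncoderBlock with forward(x, mask)."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        # Only positions where mask == 1 receive the residual update
        return x + self.linear(x) * mask


torch.manual_seed(0)
blocks = MaskedSequential(*[ToyBlock(4) for _ in range(3)])
x = torch.randn(2, 5, 4)
mask = (torch.rand(2, 5, 1) > 0.5).float()

out_seq = blocks(x, mask)

# Same computation written as the explicit loop from the question
out_loop = x
for block in blocks:
    out_loop = block(out_loop, mask)

print(torch.allclose(out_seq, out_loop))  # True
```

Because MaskedSequential subclasses nn.Sequential, the blocks are still registered as submodules and indexing, len() and iteration keep working.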
Or, if your EncoderBlocks return tuples, you can use a more general solution suggested in one of the GitHub issues:
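import torch.nn as nn

class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self:
            # Unpack a tuple output so that, e.g., a block returning
            # (x, mask) feeds both values to the next block
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs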
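With the tuple-based variant, each block must return everything the next block needs. A minimal sketch, assuming a hypothetical TupleBlock that returns (x, mask) so the mask is carried along; the container then returns the final (x, mask) tuple, which the caller unpacks:

```python
import torch
import torch.nn as nn


class MySequential(nn.Sequential):
    """Sequential container that unpacks tuple outputs into the next module."""

    def forward(self, *inputs):
        for module in self:
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class TupleBlock(nn.Module):
    """Hypothetical block: takes (x, mask) and returns (x, mask)."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        x = x + self.linear(x) * mask
        # Return the mask too, so the next block receives both arguments
        return x, mask


blocks = MySequential(*[TupleBlock(4) for _ in range(3)])
x = torch.randn(2, 5, 4)
mask = torch.ones(2, 5, 1)

z, out_mask = blocks(x, mask)
print(z.shape)  # torch.Size([2, 5, 4])
```

This version is more flexible than MaskedSequential (blocks may change or add values along the way), at the cost of every block having to follow the tuple-in, tuple-out convention.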