Jan Pisl

Reputation: 1193

Either too few or too many arguments for an nn.Sequential

I am new to PyTorch, so please excuse my silly question.

I define an nn.Sequential in the __init__ of my Encoder like this:

self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)

The forward() of EncoderBlock has this signature:

def forward(self, x, mask):

In the forward() of my Encoder, I try to do:

z0 = self.blocks(z0, mask)

I expect the nn.Sequential to pass these two arguments to individual blocks.

However, I get

TypeError: forward() takes 2 positional arguments but 3 were given

When I try:

z0 = self.blocks(z0)

I get (understandably):

TypeError: forward() takes 2 positional arguments but only 1 was given

When I do not use nn.Sequential and just execute one EncoderBlock after another, it works:

for i in range(self.n_blocks):
    z0 = self.list_of_blocks[i](z0, mask)
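If the nn.Sequential is dropped and only the explicit loop is kept, the blocks should live in an nn.ModuleList rather than a plain Python list. Otherwise their parameters are not registered with the Encoder, so they would not appear in encoder.parameters() and would not be moved by encoder.to(device). The sketch below assumes a hypothetical EncoderBlock stand-in (a linear layer whose output is multiplied by the mask) and arbitrary shapes; it is not the asker's actual block.

```python
import torch
from torch import nn


class EncoderBlock(nn.Module):
    # Hypothetical stand-in for the question's EncoderBlock:
    # a linear layer whose output is multiplied by the mask.
    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        return self.linear(x) * mask


class Encoder(nn.Module):
    def __init__(self, n_features, n_blocks):
        super().__init__()
        # nn.ModuleList registers every block as a submodule of the Encoder.
        self.blocks = nn.ModuleList([EncoderBlock(n_features) for _ in range(n_blocks)])

    def forward(self, x, mask):
        # Same explicit loop as in the question, passing the mask to each block.
        for block in self.blocks:
            x = block(x, mask)
        return x


class PlainListEncoder(nn.Module):
    def __init__(self, n_features, n_blocks):
        super().__init__()
        # A plain Python list is NOT registered, so its parameters are invisible.
        self.blocks = [EncoderBlock(n_features) for _ in range(n_blocks)]


encoder = Encoder(n_features=8, n_blocks=3)
out = encoder(torch.randn(2, 5, 8), torch.ones(2, 5, 1))
print(out.shape)                                        # torch.Size([2, 5, 8])
print(len(list(encoder.parameters())))                  # 6 (weight + bias per block)
print(len(list(PlainListEncoder(8, 3).parameters())))   # 0
```

In the question's setup this does not bite, because the same block objects are also registered through self.blocks = nn.Sequential(...).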

Question: What am I doing wrong and how do I use nn.Sequential correctly in this case?

Upvotes: 1

Views: 1491

Answers (1)

Jindřich

Reputation: 11240

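nn.Sequential in general does not work with modules that take multiple inputs or return multiple outputs: its forward() accepts exactly one input, calls the first module on it, and feeds each module's output as the sole input to the next module.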

This is a frequently discussed topic; see the PyTorch forum and GitHub issues #1908 and #9979.
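Both errors from the question can be reproduced with a minimal sketch. It assumes a hypothetical EncoderBlock stand-in (a linear layer multiplied by the mask) and arbitrary tensor shapes; only the forward(self, x, mask) signature matters here.

```python
import torch
from torch import nn


class EncoderBlock(nn.Module):
    # Hypothetical stand-in: any module whose forward takes (x, mask).
    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        return self.linear(x) * mask


blocks = nn.Sequential(EncoderBlock(8), EncoderBlock(8))
x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5, 1)

errors = []
# 1) blocks(x, mask): Sequential.forward itself takes a single input.
# 2) blocks(x): Sequential then calls block(x), so the block's mask is missing.
for args in [(x, mask), (x,)]:
    try:
        blocks(*args)
    except TypeError as err:
        errors.append(str(err))
        print(err)
```

The first call fails inside Sequential.forward, the second inside EncoderBlock.forward, which is why neither form of the call works.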

You can define your own version of Sequential. Assuming the mask is the same for all your encoder blocks (e.g., as in Transformer networks), you can do:

import torch.nn as nn

class MaskedSequential(nn.Sequential):
    def forward(self, x, mask):
        # Pass the same mask to every block; only x changes between blocks.
        for module in self._modules.values():
            x = module(x, mask)
        return x
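A usage sketch, with MaskedSequential redefined here so the snippet runs on its own. EncoderBlock is again a hypothetical stand-in (a linear layer multiplied by the mask), and the shapes are arbitrary. The check confirms that the custom Sequential gives the same result as the explicit loop from the question.

```python
import torch
from torch import nn


class EncoderBlock(nn.Module):
    # Hypothetical stand-in for the question's EncoderBlock.
    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        return self.linear(x) * mask


class MaskedSequential(nn.Sequential):
    def forward(self, x, mask):
        # Iterating an nn.Sequential yields its submodules in order.
        for module in self:
            x = module(x, mask)
        return x


blocks = [EncoderBlock(8) for _ in range(3)]
seq = MaskedSequential(*blocks)

x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5, 1)
mask[:, -1] = 0  # mask out the last position
out = seq(x, mask)

# Reference: call the blocks one after another, as in the question.
ref = x
for block in blocks:
    ref = block(ref, mask)
print(torch.allclose(out, ref))  # True
```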

Alternatively, if your EncoderBlocks return tuples (e.g., (x, mask)), you can use a more general solution suggested in one of the GitHub issues:

import torch.nn as nn

class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            # A tuple is unpacked into the next module's positional arguments.
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
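For this general version to work, each block has to return everything the next block needs. Here is a sketch that assumes a hypothetical TupleEncoderBlock returning (x, mask), with MySequential redefined so the snippet is self-contained. Because the last block also returns a tuple, the whole sequence returns (x, mask).

```python
import torch
from torch import nn


class TupleEncoderBlock(nn.Module):
    # Hypothetical block that returns (x, mask) so the next block receives both.
    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, mask):
        return self.linear(x) * mask, mask


class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self:
            # Unpack tuples into positional arguments; pass anything else as-is.
            inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
        return inputs


seq = MySequential(*[TupleEncoderBlock(8) for _ in range(3)])
x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5, 1)
out, out_mask = seq(x, mask)
print(out.shape)  # torch.Size([2, 5, 8])
```

The trade-off is that every block must follow the same tuple convention; when the mask never changes, MaskedSequential is simpler.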

Upvotes: 5
