the-bass

Reputation: 745

How do I add LSTM, GRU or other recurrent layers to a Sequential in PyTorch

I like using torch.nn.Sequential as in

self.conv_layer = torch.nn.Sequential(
    torch.nn.Conv1d(196, 196, kernel_size=15, stride=4),
    torch.nn.Dropout()
)

But when I want to add a recurrent layer such as torch.nn.GRU, it won't work, because recurrent layers in PyTorch return a tuple and you need to choose which part of the output you want to process further.
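To make the problem concrete, here is a minimal sketch (shapes chosen arbitrarily) showing that a GRU returns a tuple of (output, h_n) rather than a single tensor, which is what breaks nn.Sequential:

```python
import torch

gru = torch.nn.GRU(input_size=2, hidden_size=256)
x = torch.randn(10, 1, 2)  # (seq_len, batch, input_size)

output, h_n = gru(x)  # a GRU returns a tuple, not a single tensor
print(output.shape)   # (10, 1, 256): hidden state at every time step
print(h_n.shape)      # (1, 1, 256): final hidden state of each layer
```

Because nn.Sequential simply feeds each module's return value into the next module, the following Linear layer would receive this tuple and fail.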

So is there any way to get

self.rec_layer = nn.Sequential(
    torch.nn.GRU(input_size=2, hidden_size=256),
    torch.nn.Linear(in_features=256, out_features=1)
)

to work? For this example, let's say I want to feed torch.nn.GRU(input_size=2, hidden_size=256)(x)[1][-1] (the last hidden state of the last layer) into the following Linear layer.
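For reference, the usual alternative to nn.Sequential is a plain nn.Module that does the tuple indexing by hand in forward(); a minimal sketch of that workaround (class name RecNet is made up here):

```python
import torch

class RecNet(torch.nn.Module):
    """Workaround without Sequential: index the GRU's tuple in forward()."""
    def __init__(self):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=2, hidden_size=256)
        self.linear = torch.nn.Linear(in_features=256, out_features=1)

    def forward(self, x):
        _, h_n = self.gru(x)         # h_n: (num_layers, batch, hidden_size)
        return self.linear(h_n[-1])  # last layer's final hidden state
```

This works, but the question is whether the same thing can be expressed inside a Sequential.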

Upvotes: 2

Views: 2859

Answers (1)

user3098048

Reputation: 151

I made a module called SelectItem to pick out an element from a tuple or list:

import torch.nn as nn

class SelectItem(nn.Module):
    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        # Index into the tuple (or list) returned by the previous layer
        return inputs[self.item_index]

SelectItem can be used in Sequential to pick out the hidden state:

net = nn.Sequential(
    nn.GRU(dim_in, dim_out, batch_first=True),
    SelectItem(1)
)

Upvotes: 6
