Reputation: 13
I am new to PyTorch; I started using it about a month ago. This is my GRU code:
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2
encoder = nn.GRU(hidden_size,
                 hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)
print(op.shape, hn.shape)
the output is:
torch.Size([64, 100, 64]) torch.Size([4, 64, 32])
Here I am mainly concerned with the shape of hn: its first dimension has size 4, so I assume this is 2 layers * 2 directions. However, I am a little confused about the arrangement. Are the first 2 entries the forward hidden states and the last 2 the backward hidden states, or do forward and backward hidden states alternate?
Is the following the correct way to extract only the forward GRU states?
forward_hidden = hn[[x for x in range(0, gru_layers_count * 2, 2)], :, :]
Upvotes: 1
Views: 24
Reputation: 5508
The shape of the hidden state is (D*num_layers, N, H_out), where D=2 for a bidirectional GRU. The hidden state tensor alternates forward/backward for each layer: row 2*layer holds that layer's forward state and row 2*layer + 1 its backward state. The documentation isn't super clear, but this is what it means by "For bidirectional GRUs, forward and backward are directions 0 and 1 respectively".
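One way to make the (layer, direction) layout explicit is to reshape hn, as the PyTorch docs suggest with h_n.view(num_layers, num_directions, N, H_out). A minimal sketch, reusing the question's sizes:

```python
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2
encoder = nn.GRU(hidden_size, hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)

# (D*num_layers, N, H_out) -> (num_layers, D, N, H_out)
hn_view = hn.view(gru_layers_count, 2, 64, hidden_size)

# hn_view[layer, 0] is the forward state, hn_view[layer, 1] the backward one,
# i.e. the rows of hn alternate forward/backward within each layer.
assert torch.equal(hn_view[0, 0], hn[0])  # layer 0, forward
assert torch.equal(hn_view[0, 1], hn[1])  # layer 0, backward
assert torch.equal(hn_view[1, 0], hn[2])  # layer 1, forward
assert torch.equal(hn_view[1, 1], hn[3])  # layer 1, backward
```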
You can separate the forward/backward states as follows:
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2
encoder = nn.GRU(hidden_size,
                 hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)
forward_states = hn[0::2]
backward_states = hn[1::2]
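If you want to convince yourself the slicing is right, you can check it against op: with batch_first=True, op's last timestep (first half of the feature dimension) is the last layer's forward final state, and op's first timestep (second half) is the last layer's backward final state. A quick verification sketch:

```python
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2
encoder = nn.GRU(hidden_size, hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)

forward_states = hn[0::2]   # forward hidden state of each layer
backward_states = hn[1::2]  # backward hidden state of each layer

# Last layer, forward direction == op at the final timestep (first H features)
assert torch.allclose(forward_states[-1], op[:, -1, :hidden_size], atol=1e-6)
# Last layer, backward direction == op at the first timestep (last H features)
assert torch.allclose(backward_states[-1], op[:, 0, hidden_size:], atol=1e-6)
```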
Upvotes: 0