ashwani kumar dwivedi

Reputation: 13

Can anyone explain the h_n output of a GRU layer?

I am new to PyTorch; I started coding with it a month ago. This is my GRU code:

import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2

encoder = nn.GRU(hidden_size,
                 hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)
print(op.shape, hn.shape)

the output is:

torch.Size([64, 100, 64]) torch.Size([4, 64, 32])

Here I am mainly concerned with the shape of hn: its first dimension has size 4, so I assume it is 2 layers × 2 directions. However, I am a little confused about the arrangement.

So my question is: are the first 2 entries the forward hidden states and the last 2 the backward hidden states, or do forward and backward hidden states alternate?

Is the following the correct way to extract only the forward GRU states?

forward_hidden = hn[[x for x in range(0, gru_layers_count * 2, 2)], :, :]

Upvotes: 1

Views: 24

Answers (1)

Karl

Reputation: 5508

The shape of the hidden state is (D * num_layers, N, H_out), where D = 2 for a bidirectional GRU.

The hidden state tensor alternates forward/backward within each layer, i.e. the order is [layer 0 forward, layer 0 backward, layer 1 forward, layer 1 backward, ...]. The documentation isn't super clear, but this is what is meant by "For bidirectional GRUs, forward and backward are directions 0 and 1 respectively."

You can separate the forward/backward states as follows:

import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2

encoder = nn.GRU(hidden_size,
                 hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)

forward_states = hn[0::2]   # indices 0, 2: forward direction of each layer
backward_states = hn[1::2]  # indices 1, 3: backward direction of each layer
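You can sanity-check the alternating layout against `op`, which holds the last layer's outputs at every time step: the forward output at the final time step should equal the last layer's forward final hidden state (index `2 = layer 1 * 2 directions + 0`), and the backward output at the first time step should equal the last layer's backward final hidden state (index 3). A minimal sketch:

```python
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2

encoder = nn.GRU(hidden_size, hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)
op, hn = encoder(ip)

# op[:, t, :hidden_size] is the forward output at step t;
# op[:, t, hidden_size:] is the backward output at step t.
fwd_match = torch.allclose(op[:, -1, :hidden_size], hn[2])
bwd_match = torch.allclose(op[:, 0, hidden_size:], hn[3])
print(fwd_match, bwd_match)
```

If the layout were instead [all forward layers, then all backward layers], `hn[2]` would be a backward state and the first check would fail. Equivalently, `hn.view(gru_layers_count, 2, 64, hidden_size)` separates the layer and direction axes explicitly.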

Upvotes: 0
