Robin Sharma

Reputation: 191

Multi-Layer Bidirectional LSTM/GRU merge modes in PyTorch

I am trying to replicate my code from Keras in PyTorch to compare the performance of multi-layer bidirectional LSTM/GRU models on CPUs and GPUs. I would like to look into different merge modes such as 'concat' (the default in PyTorch), 'sum', 'mul', and 'ave'. The merge mode defines how the outputs from the forward and backward directions are combined before being passed to the next layer.

In Keras, switching the merge mode of a multi-layer bidirectional LSTM/GRU model is just an argument change. Does something similar exist in PyTorch? One option is to perform the merge operation manually after every layer and pass the result to the next layer, but since I want to study performance, I would like to know whether there is a more efficient way.
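A rough sketch of the kind of Keras model I mean (layer sizes and input shape are just placeholders):

# Keras: the merge mode is just a constructor argument of the
# Bidirectional wrapper ('concat', 'sum', 'mul', 'ave', or None).
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential

model = Sequential([
    Bidirectional(LSTM(100, return_sequences=True), merge_mode='sum',
                  input_shape=(None, 5)),
    Bidirectional(LSTM(100), merge_mode='sum'),
])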

Thanks,

Upvotes: 2

Views: 591

Answers (1)

Erik

Reputation: 74

To the best of my knowledge, there is no more efficient way than implementing it yourself in PyTorch, i.e., there is no simple argument option.

As you said, PyTorch's default corresponds to TensorFlow/Keras's 'concat' mode. If we want to verify this, we can test it as follows:

import torch
from torch import nn

# Create the LSTMs
in_dim = 5
out_dim = 100
lstm = nn.LSTM(in_dim, out_dim, batch_first=True)
bilstm = nn.LSTM(in_dim, out_dim, batch_first=True, bidirectional=True)

# Copy forward weights
bilstm.weight_ih_l0 = lstm.weight_ih_l0
bilstm.weight_hh_l0 = lstm.weight_hh_l0
bilstm.bias_ih_l0 = lstm.bias_ih_l0
bilstm.bias_hh_l0 = lstm.bias_hh_l0

# Execute on random example
x = torch.randn(1, 3, in_dim)
output1, (h_n1, c_n1) = lstm(x)
output2, (h_n2, c_n2) = bilstm(x)

# Assert equality of the forward loops
assert torch.allclose(output1, output2[:, :, :out_dim])  # Output is the same
assert torch.allclose(h_n1, h_n2[0])  # Hidden state is the same
assert torch.allclose(c_n1, c_n2[0])  # Cell state is the same

For future reference, the following are examples of the other three merge modes:

Initialization

in_dim = 5
out_dim = 100
bilstm = nn.LSTM(in_dim, out_dim, batch_first=True, bidirectional=True)
x = torch.randn(1, 3, in_dim)
output, (h_n, c_n) = bilstm(x)

Sum ('sum')

# Merge Mode: 'sum'

# Simple version
output_sum = output[:, :, :out_dim] + output[:, :, out_dim:]
assert output_sum.shape == (1, 3, out_dim)

# Faster version
output_sum2 = torch.sum(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_sum, output_sum2)

On my machine, the "faster version" needs approx. half the time of the simple version.
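To reproduce the timing comparison yourself, here is a rough sketch using timeit (reusing output, out_dim, and x from the initialization above; the absolute numbers will vary by machine and tensor size):

# Rough timing sketch: compare the two 'sum' variants over many runs
import timeit

simple = lambda: output[:, :, :out_dim] + output[:, :, out_dim:]
faster = lambda: torch.sum(output.view(x.size(0), x.size(1), 2, -1), dim=2)

print(timeit.timeit(simple, number=10000))
print(timeit.timeit(faster, number=10000))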

Multiplication ('mul')

# Merge Mode: 'mul'

# Simple version
output_mul = output[:, :, :out_dim] * output[:, :, out_dim:]
assert output_mul.shape == (1, 3, out_dim)

# Faster version
output_mul2 = torch.prod(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_mul, output_mul2)

Average ('ave')

# Merge Mode: 'ave'

# Simple version
output_ave = (output[:, :, :out_dim] + output[:, :, out_dim:]) / 2
assert output_ave.shape == (1, 3, out_dim)

# Faster version
output_ave2 = torch.mean(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_ave, output_ave2)

Again, the faster version takes approx. 50% of the time of the simple version on my device.
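If you want the merge mode applied between the layers of a multi-layer model (as the Keras Bidirectional wrapper does), you would stack single-layer bidirectional LSTMs yourself. A minimal sketch of such a wrapper (the module name and structure are my own, not part of PyTorch; it reuses torch and nn from the snippets above):

class MergedBiLSTM(nn.Module):
    """Stack of bidirectional LSTM layers with a merge mode
    ('concat', 'sum', 'mul', 'ave') applied between layers."""

    def __init__(self, in_dim, out_dim, num_layers=2, merge_mode='concat'):
        super().__init__()
        self.merge_mode = merge_mode
        layers = []
        layer_in = in_dim
        for _ in range(num_layers):
            layers.append(nn.LSTM(layer_in, out_dim, batch_first=True,
                                  bidirectional=True))
            # Only 'concat' doubles the feature size fed to the next layer
            layer_in = 2 * out_dim if merge_mode == 'concat' else out_dim
        self.layers = nn.ModuleList(layers)

    def _merge(self, output):
        # output: (batch, seq, 2 * out_dim) -> (batch, seq, 2, out_dim)
        stacked = output.view(output.size(0), output.size(1), 2, -1)
        if self.merge_mode == 'concat':
            return output
        if self.merge_mode == 'sum':
            return stacked.sum(dim=2)
        if self.merge_mode == 'mul':
            return stacked.prod(dim=2)
        if self.merge_mode == 'ave':
            return stacked.mean(dim=2)
        raise ValueError(f"Unknown merge mode: {self.merge_mode}")

    def forward(self, x):
        for lstm in self.layers:
            x, _ = lstm(x)
            x = self._merge(x)
        return x

Usage, continuing the example above:

model = MergedBiLSTM(in_dim=5, out_dim=100, num_layers=3, merge_mode='sum')
y = model(torch.randn(1, 3, 5))
assert y.shape == (1, 3, 100)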

I hope this helps people who find this in the future. :)

Upvotes: 0
