MichaelSB

Reputation: 3181

How to parallelize an RNN function in PyTorch with DataParallel

Here's an RNN model for character-based language generation:

import torch
import torch.nn as nn
from torch.autograd import Variable  # required before PyTorch 0.4; deprecated since


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        self.encoder = nn.Embedding(input_size, hidden_size)
        # batch_first=True: GRU input/output are (batch, seq_len, features)
        self.GRU = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, batch_size):
        self.init_hidden(batch_size)
        input = self.encoder(input)  # (batch, 1) -> (batch, 1, hidden_size)
        output, self.hidden = self.GRU(input, self.hidden)
        # Flattening to (batch, hidden_size) assumes one character per step (seq_len == 1)
        output = self.decoder(output.view(batch_size, self.hidden_size))
        return output

    def init_hidden(self, batch_size):
        # Hidden state layout: (n_layers, batch, hidden_size)
        self.hidden = Variable(torch.randn(self.n_layers, batch_size, self.hidden_size).cuda())

I instantiate the model with DataParallel to split the batch of inputs across my 4 GPUs:

net = torch.nn.DataParallel(RNN(n_chars, hidden_size, n_chars, n_layers)).cuda()

Here's the full code.

Unfortunately, DataParallel splits the inputs along the first dimension by default, so batch_size has to be the first dimension, but the GRU expects the hidden tensor to have batch_size as its second dimension (even with batch_first=True):

output, self.hidden = self.GRU(input, self.hidden)
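To see the layout mismatch concretely: with batch_first=True the GRU's input and output are (batch, seq_len, features), but the hidden state stays (n_layers, batch, hidden_size). The following is a minimal CPU-only sketch; the small sizes are made up for illustration and are not taken from the question's full code.

```python
import torch
import torch.nn as nn

# Small illustrative sizes (assumed, not from the question)
n_layers, batch, seq_len, hidden_size = 2, 4, 1, 8

gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)

# batch_first=True only affects input/output: (batch, seq_len, features)
x = torch.randn(batch, seq_len, hidden_size)

# The hidden state keeps batch in dim 1 regardless: (n_layers, batch, hidden_size)
h0 = torch.zeros(n_layers, batch, hidden_size)

out, h_n = gru(x, h0)
print(out.shape)  # torch.Size([4, 1, 8])
print(h_n.shape)  # torch.Size([2, 4, 8])
```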

As written, the code throws the following error (note the printouts showing that the encoder runs correctly on all 4 GPUs):

...
forward function: encoding input of shape: (16L, 1L)
forward function: encoding input of shape: (16L, 1L)
forward function: encoding input of shape: (16L, 1L)
forward function: encoding input of shape: (16L, 1L)
forward function: GRU processing input of shape: (16L, 1L, 100L)
forward function: GRU processing input of shape: (16L, 1L, 100L)
forward function: GRU processing input of shape: (16L, 1L, 100L)
forward function: GRU processing input of shape: (16L, 1L, 100L)

Traceback (most recent call last):
  File "gru2.py", line 166, in <module>
    output = net(c, batch_size)
  File "/root/miniconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/miniconda2/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 61, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/root/miniconda2/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 71, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs)
  File "/root/miniconda2/lib/python2.7/site-packages/torch/nn/parallel/parallel_apply.py", line 45, in parallel_apply
    raise output
RuntimeError: Expected hidden size (2, 16L, 100), got (2L, 64L, 100L)

Here the model has n_layers = 2, batch_size = 64, and hidden_size = 100.
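The shapes in the message hint at the cause: DataParallel scatters tensor arguments along dim 0 (64 rows become four chunks of 16), but a non-tensor argument such as the int batch_size is copied unchanged to every replica, so each replica builds a (2, 64, 100) hidden state for a 16-row input. The sketch below imitates that scatter on the CPU with torch.chunk (an assumption so it runs without GPUs; it is not DataParallel itself) and uses a hypothetical variant, RNNFromInput, that reads the batch size from its input chunk. It also swaps the random initial hidden state for zeros, which is a design choice, not something from the question.

```python
import torch
import torch.nn as nn


class RNNFromInput(nn.Module):
    """Hypothetical variant of the question's RNN: batch size comes from the input."""

    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # Under DataParallel, `input` is this replica's chunk, so size(0) is the
        # per-replica batch (16 here), not the global batch (64).
        batch_size = input.size(0)
        # Build the hidden state on the same device as the chunk.
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_size,
                             device=input.device)
        output, hidden = self.gru(self.encoder(input), hidden)
        # Assumes seq_len == 1, as in the question.
        return self.decoder(output.reshape(batch_size, self.hidden_size))


n_chars, hidden_size, n_layers, batch_size = 50, 100, 2, 64
net = RNNFromInput(n_chars, hidden_size, n_chars, n_layers)
c = torch.randint(0, n_chars, (batch_size, 1))  # (batch, seq_len=1)

# Imitate DataParallel's default scatter: split tensors along dim 0 into 4 chunks.
chunks = torch.chunk(c, 4, dim=0)
outputs = [net(chunk) for chunk in chunks]  # each output is (16, 50)
result = torch.cat(outputs, dim=0)          # gather along dim 0
print(result.shape)  # torch.Size([64, 50])
```

On real GPUs the same module would be wrapped as torch.nn.DataParallel(RNNFromInput(...)).cuda() and called as net(c.cuda()). The hidden state is not carried between calls here; if it must be, it has to be passed in as a tensor argument so that it gets scattered too.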

How do I parallelize the GRU operation in the forward function?

Upvotes: 3

Views: 3884

Answers (2)

M. Mortazavi

Reputation: 95

PyTorch 1.5 has completely fixed the issues with RNN training and DataParallel, and it seems to have done so quite seamlessly. No more gerrymandering is required. I confirmed this today in a project involving bidirectional GRUs on speech MFCCs.


import torch.nn as nn


class PEncoder(nn.Module):
    def __init__(self, args, encoder):
        super(PEncoder, self).__init__()
        self.gpu_ids = args.gpu_ids  # e.g. [0, 1, 2, 3]
        self.model = encoder

    def forward(self, input):
        if len(self.gpu_ids) > 1:
            # Scatter the input across the GPUs, run a replica on each, gather the results
            return nn.parallel.data_parallel(self.model, (input,), self.gpu_ids)
        else:
            return self.model(input)

It is that simple. This does wrap your model in another module, so the parameter names in its state_dict gain a "model." prefix. If you have earlier trained models, you may therefore have to load them in a special way, or add some setters for this parallel wrapper. Give it a try and you'll see. (I've not confirmed this aspect of it.)
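One way the "special manner" of loading could look, as a minimal sketch: because PEncoder stores the encoder as self.model, the wrapper's state_dict keys are the old keys with a "model." prefix. Here a bidirectional nn.GRU stands in for the encoder and types.SimpleNamespace stands in for args; both are assumptions for illustration, and gpu_ids is empty so the sketch runs on a CPU-only machine.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class PEncoder(nn.Module):
    # Same wrapper as in the answer above.
    def __init__(self, args, encoder):
        super().__init__()
        self.gpu_ids = args.gpu_ids
        self.model = encoder

    def forward(self, input):
        if len(self.gpu_ids) > 1:
            return nn.parallel.data_parallel(self.model, (input,), self.gpu_ids)
        return self.model(input)


def make_encoder():
    # Stand-in encoder (assumed): 2-layer bidirectional GRU, batch-first input
    return nn.GRU(8, 16, num_layers=2, bidirectional=True, batch_first=True)


old_state = make_encoder().state_dict()  # checkpoint saved before wrapping

args = SimpleNamespace(gpu_ids=[])       # CPU-only here; e.g. [0, 1, 2, 3] on 4 GPUs
wrapped = PEncoder(args, make_encoder())

# Option 1: load into the inner module, whose keys carry no prefix.
wrapped.model.load_state_dict(old_state)

# Option 2: add the prefix and load into the wrapper itself.
wrapped.load_state_dict({"model." + k: v for k, v in old_state.items()})

out, h_n = wrapped(torch.randn(4, 5, 8))  # (batch, seq_len, features)
print(out.shape, h_n.shape)  # torch.Size([4, 5, 32]) torch.Size([4, 4, 16])
```

Option 1 leaves the checkpoint untouched; option 2 is useful when the whole wrapper is later saved and restored as one unit.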

Upvotes: 1

dalegebit

Reputation: 41

You can simply set the parameter dim=1, so that DataParallel splits the inputs along their second dimension, e.g.

net = torch.nn.DataParallel(RNN(n_chars, hidden_size, n_chars, n_layers), dim=1).cuda()
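With dim=1, DataParallel scatters every tensor argument along dimension 1 and gathers the outputs along dimension 1 as well. That matches the GRU's hidden state layout (n_layers, batch, hidden_size) and time-major (seq_len, batch) inputs, i.e. batch_first=False; a batch-first (batch, seq_len) input would instead be split along its sequence dimension. The sketch below assumes time-major inputs and a hidden state passed in as an argument (TimeMajorRNN is a hypothetical name) and imitates the dim=1 scatter/gather on the CPU with torch.chunk and torch.cat rather than calling DataParallel.

```python
import torch
import torch.nn as nn


class TimeMajorRNN(nn.Module):
    """Hypothetical time-major variant: batch is dim 1 of input, hidden and output."""

    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super().__init__()
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)  # batch_first=False
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        # input: (seq_len, batch); hidden: (n_layers, batch, hidden_size)
        output, hidden = self.gru(self.encoder(input), hidden)
        return self.decoder(output), hidden  # output: (seq_len, batch, output_size)


n_chars, hidden_size, n_layers, batch_size = 50, 100, 2, 64
net = TimeMajorRNN(n_chars, hidden_size, n_chars, n_layers)
c = torch.randint(0, n_chars, (1, batch_size))       # (seq_len=1, batch=64)
h = torch.zeros(n_layers, batch_size, hidden_size)   # (2, 64, 100)

# Imitate DataParallel(..., dim=1): split every tensor argument along dim 1 ...
pieces = [net(ci, hi) for ci, hi in zip(torch.chunk(c, 4, dim=1),
                                        torch.chunk(h, 4, dim=1))]
# ... and gather every output along dim 1.
output = torch.cat([o for o, _ in pieces], dim=1)
hidden = torch.cat([hn for _, hn in pieces], dim=1)
print(output.shape)  # torch.Size([1, 64, 50])
print(hidden.shape)  # torch.Size([2, 64, 100])
```

On the GPUs, the corresponding call would be torch.nn.DataParallel(TimeMajorRNN(...), dim=1).cuda() followed by net(c.cuda(), h.cuda()).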

Upvotes: 4
