Daxi Song

Reputation: 67

About using RNN in PyTorch

I am trying to use an RNN for binary classification, but while the model is training it gets stuck at loss.backward(). Here is my model:

import torch
import torch.nn as nn

class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=2, num_layers=1):
        super(RNN2, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers)
        self.reg = nn.Linear(hidden_size, output_size)
        #self.softmax = nn.LogSoftmax(dim=1)

    def forward(self,x):
        x, hidden = self.rnn(x)
        return self.reg(x[:,2])

rnn = RNN2(13,10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
for e in range(10):
    out = rnn(train_X)
    optimizer.zero_grad()
    print(out[0])
    print(out.shape)
    print(train_Y.shape)
    loss = criterion(out, train_Y)
    print(loss)
    loss.backward()
    print("1")
    optimizer.step()
    print("2")

The shape of train_X is (420000, 3, 13) and the shape of train_Y is (420000,), so the loss can be computed and printed. Can anyone tell me why it gets stuck at loss.backward()? It never prints 1.

Upvotes: 0

Views: 157

Answers (1)

Mohammad Arvan

Reputation: 623

Note that in RNNs, backpropagating through a sequence of length 420000 is extremely slow: the gradient has to flow back through all 420000 time steps one after another, and that recurrence cannot be parallelized across time. If you run your code on a machine with a GPU (or on Google Colab) and add the following lines before the for loop, it finishes in less than two minutes.

rnn = rnn.cuda()
train_X = train_X.cuda()
train_Y = train_Y.cuda()
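The .cuda() calls above fail on a machine without a GPU. A minimal sketch of a device-agnostic variant, assuming you want the same script to run on both kinds of machine; a bare nn.RNN and small random tensors stand in for the question's RNN2, train_X and train_Y:

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the question's model and data (much smaller here).
rnn = nn.RNN(input_size=13, hidden_size=10, batch_first=True).to(device)
train_X = torch.randn(100, 3, 13).to(device)
train_Y = torch.randint(0, 2, (100,)).to(device)

# Module.to() moves the parameters in place and returns the module;
# Tensor.to() returns a new tensor, hence the reassignment above.
out, h_n = rnn(train_X)
print(out.device, out.shape)
```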

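Note that by default (batch_first=False), nn.RNN expects input of shape (seq_len, batch, input_size), so the second dimension is treated as the batch size. Your (420000, 3, 13) tensor is therefore read as a single sequence of 420000 time steps with a batch of 3. If 420000 is actually the number of samples (each a sequence of 3 steps), pass batch_first=True to the RNN constructor: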

self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
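To see how the layout changes the meaning of the same tensor, here is a small sketch with 8 samples instead of 420000 (the sizes are illustrative only). It also shows why the original code raises no error: x[:, 2] has the same shape in both layouts, but selects different data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in for train_X: 8 samples, each 3 steps of 13 features.
x = torch.randn(8, 3, 13)

# Default layout: (seq_len, batch, input_size), i.e. ONE sequence of
# length 8 processed for a batch of 3.
rnn_default = nn.RNN(13, 10)
out_default, h_default = rnn_default(x)
print(out_default.shape)  # torch.Size([8, 3, 10]) -> (seq_len, batch, hidden)
print(h_default.shape)    # torch.Size([1, 3, 10]) -> batch of 3

# batch_first=True: (batch, seq_len, input_size), i.e. 8 independent
# sequences of length 3.
rnn_bf = nn.RNN(13, 10, batch_first=True)
out_bf, h_bf = rnn_bf(x)
print(out_bf.shape)  # torch.Size([8, 3, 10]) -> (batch, seq_len, hidden)
print(h_bf.shape)    # torch.Size([1, 8, 10]) -> batch of 8 (h_n ignores batch_first)

# Same shape, different data:
# - default:     batch element 2 at each of the 8 time steps
# - batch_first: the last (3rd) time step of each of the 8 samples
print(out_default[:, 2].shape, out_bf[:, 2].shape)
```

With batch_first=True and a single unidirectional layer, out_bf[:, 2] equals the final hidden state h_bf[0], which is the usual input for a sequence classifier.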

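Using batch_first=True speeds things up dramatically (less than one second on Google Colab), because each sequence is now only 3 steps long. It also changes what x[:,2] in forward selects: with batch_first=True it is the output at the last time step of every sample, whereas with the default layout it is batch element 2 at every one of the 420000 time steps. If 420000 really is the length of one sequence, you should instead try chunking it into shorter sequences and increasing the batch size from 3 to a larger value.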
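For the case where 420000 really is one long sequence, here is a minimal sketch of the chunking idea. It assumes (hypothetically) a single sequence of shape (T, features) with one label per time step, and it does not carry the hidden state across chunks, so gradients flow through at most chunk_len steps (a simple form of truncated backpropagation through time). ChunkedRNN, chunk_sequence, chunk_len and the random data are illustrative names, not from the question.

```python
import torch
import torch.nn as nn

class ChunkedRNN(nn.Module):
    """Hypothetical per-time-step classifier using the batch_first layout."""
    def __init__(self, input_size, hidden_size, output_size=2):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.reg = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)   # (num_chunks, chunk_len, hidden)
        return self.reg(out)   # (num_chunks, chunk_len, output_size)

def chunk_sequence(x, y, chunk_len):
    """Split x: (T, F) and per-step labels y: (T,) into non-overlapping
    chunks; a trailing remainder shorter than chunk_len is dropped."""
    num_chunks = x.shape[0] // chunk_len
    usable = num_chunks * chunk_len
    x_chunks = x[:usable].reshape(num_chunks, chunk_len, x.shape[1])
    y_chunks = y[:usable].reshape(num_chunks, chunk_len)
    return x_chunks, y_chunks

torch.manual_seed(0)
# Hypothetical data: one sequence of 10_000 steps (not 420000), 13 features.
seq_x = torch.randn(10_000, 13)
seq_y = torch.randint(0, 2, (10_000,))
x_chunks, y_chunks = chunk_sequence(seq_x, seq_y, chunk_len=100)

model = ChunkedRNN(13, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    optimizer.zero_grad()
    logits = model(x_chunks)  # 100 chunks form the batch dimension
    # CrossEntropyLoss expects (N, C) logits and (N,) targets,
    # so flatten the chunk and time dimensions together.
    loss = criterion(logits.reshape(-1, 2), y_chunks.reshape(-1))
    loss.backward()  # backprop through 100 steps instead of 10_000
    optimizer.step()
    print(epoch, loss.item())
```

Larger chunks keep longer dependencies but make each backward pass slower; carrying the detached hidden state from one chunk to the next is a common refinement if order across chunks matters.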

Upvotes: 1
