RNN

Reputation: 25

How can I run pytorch with multiple graphic cards?

I have 4 graphic cards which I want to utilize to pytorch. I have this net:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

How can I use them on this net?

Upvotes: 1

Views: 147

Answers (1)

Szymon Maszke

Reputation: 24691

You may use torch.nn.DataParallel to replicate your model across multiple GPUs and split each input batch among them.

Just pass your network (a torch.nn.Module) to its constructor and call it as you normally would. You may also specify which GPUs it should run on by passing device_ids as a List[int] or a list of torch.device.

Just for the sake of code:

import torch

# Your network
network = Net()

# Wrap it; note the assignment back - DataParallel returns a new module
network = torch.nn.DataParallel(network)

# Use however you wish; call the module itself rather than .forward
output = network(data)
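To pin the wrapper to specific devices (the question mentions 4 cards, so GPU ids 0-3 are assumed here), a minimal sketch might look like the following. A tiny stand-in model is used in place of Net so the snippet is self-contained:

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; substitute your own Net() here
model = nn.Linear(10, 2)

if torch.cuda.device_count() > 1:
    # Replicate across GPUs 0-3; each input batch is split along dim 0
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.cuda()

# Each forward call scatters the batch across the replicas in parallel
# and gathers the outputs back on device_ids[0]
x = torch.randn(8, 10)
out = model(x)
print(out.shape)  # torch.Size([8, 2])
```

On a machine without multiple GPUs the wrapper is skipped and the model simply runs on a single device, so the snippet degrades gracefully.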

Upvotes: 1
