Reputation: 25
I have 4 graphics cards that I want to use with PyTorch. I have this net:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
How can I use all of them with this net?
Upvotes: 1
Views: 147
Reputation: 24691
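You can use torch.nn.DataParallel to spread your model across several GPUs. It replicates the module on each device, splits every input batch along its first dimension (dim 0), runs the replicas in parallel and gathers the outputs back on the first device.
Just pass your network (a torch.nn.Module) to its constructor and call the wrapper as you normally would. You can also choose which GPUs it runs on by passing device_ids as a list of ints or torch.device objects; by default all visible GPUs are used. The model's parameters must be on device_ids[0] before you run it.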
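To make the batch splitting concrete, here is a small stdlib-only model of the splitting rule. It assumes, to my understanding of the PyTorch source, that DataParallel's scatter step uses the same rule as torch.Tensor.chunk when no explicit chunk sizes are given: every chunk holds ceil(batch / devices) samples except possibly the last, and only as many replicas as there are chunks are used. The function name chunk_sizes is hypothetical and not part of PyTorch.

```python
import math


def chunk_sizes(batch_size, num_devices):
    """Hypothetical helper modelling how torch.Tensor.chunk splits dim 0.

    Each chunk gets ceil(batch_size / num_devices) samples, except possibly
    the last one, so small batches can end up using fewer devices.
    """
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    per_chunk = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(per_chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


print(chunk_sizes(64, 4))  # [16, 16, 16, 16]
print(chunk_sizes(10, 4))  # [3, 3, 3, 1]
print(chunk_sizes(2, 4))   # [1, 1]
```

The last case is the practical takeaway: with 4 GPUs, a batch smaller than 4 cannot keep every card busy, so larger batch sizes make better use of DataParallel.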
For example:
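import torch

# Your network (Net from the question)
network = Net()
# DataParallel returns a new wrapper module, so keep the result,
# and put the parameters on cuda:0, the first entry of device_ids
network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3]).to("cuda:0")
# Call the wrapper itself (not .forward); data is an (N, 1, 28, 28) batch
# that gets split across the four GPUs
output = network(data.to("cuda:0"))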
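A fuller sketch of one training step is below. It assumes MNIST-sized 1x28x28 inputs (which is what 4*4*50 in fc1 implies), uses random tensors as hypothetical stand-ins for real data, and picks SGD with lr=0.01 arbitrarily. It only wraps the model in DataParallel when more than one GPU is visible, so the same script also runs on a single GPU or on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same architecture as in the question (expects 1x28x28 inputs)
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


def train_step(model, optimizer, data, target, device):
    # Inputs go to the primary device; DataParallel scatters them from there
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)               # call the module, not model.forward
    loss = F.nll_loss(output, target)  # pairs with the log_softmax output
    loss.backward()
    optimizer.step()
    return loss.item()


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)  # parameters must be on device_ids[0] before wrapping
if torch.cuda.device_count() > 1:
    # Replicate across all visible GPUs; each replica gets a slice of the batch
    model = nn.DataParallel(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Hypothetical random batch standing in for real data
data = torch.randn(64, 1, 28, 28)
target = torch.randint(0, 10, (64,))
loss = train_step(model, optimizer, data, target, device)
print(f"loss: {loss:.4f}")
```

Once wrapped, the original network lives under model.module, so a saved state_dict gets a "module." prefix on every key; saving model.module.state_dict() avoids that. The PyTorch documentation also recommends DistributedDataParallel over DataParallel for multi-GPU training, which may be worth considering if DataParallel's single-process overhead becomes a bottleneck.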
Upvotes: 1