Jiho Choi

Reputation: 1311

Best practice to pass PyTorch device name to model

Currently, I have separated train.py from model.py in my deep learning project.

The dataset batches are sent to the CUDA device inside the epoch loop, as shown below.

train.py

...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)
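Below is a self-contained version of this loop that runs on either CPU or GPU. `TwoInputNet`, the random tensors and the hyperparameters are hypothetical stand-ins for `MyNet` and `train_loader`. The sketch shows only that the model is moved once, before the loop, and each batch is moved inside the loop.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Pick the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for MyNet that, like the loop above, takes two inputs."""

    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        self.lin = nn.Linear(2 * in_feats, out_feats)

    def forward(self, s0: torch.Tensor, s1: torch.Tensor) -> torch.Tensor:
        return self.lin(torch.cat((s0, s1), dim=1))


# Toy data (hypothetical): 32 samples of two 8-dim inputs and a scalar target.
dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 8), torch.randn(32, 1))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

model = TwoInputNet(8, 1).to(device)  # move the parameters once, before training
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for epoch in range(2):
    for batch_data in train_loader:
        # Each batch comes out of the DataLoader on the CPU; move it to the model's device.
        s0 = batch_data[0].to(device)
        s1 = batch_data[1].to(device)
        target = batch_data[2].to(device)

        pred = model(s0, s1)
        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Moving the model once keeps its parameters on the device for the whole run. The batches have to be moved every iteration because the DataLoader produces CPU tensors.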

However, my model (in model.py) also needs access to the device variable for a skip-connection-like step. It builds a new tensor from a copy of the hidden units (for a residual connection):

model.py

import copy

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <-- needs `device`
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)

Currently I pass device to the model as a parameter, but I suspect this is not best practice.

  1. Where is the best place to send the dataset to CUDA?
  2. If multiple scripts need to access the device, how should I handle this? (As a parameter? A global variable?)
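Question 2 can be handled by choosing the device in exactly one place and having every script import a helper, instead of sharing a mutable global. The sketch below assumes this approach. The module name `device_utils.py`, the `get_device` helper and the `DEVICE` environment variable are all hypothetical.

```python
# device_utils.py (hypothetical shared module imported by train.py and model.py)
import functools
import os

import torch


@functools.lru_cache(maxsize=None)
def get_device() -> torch.device:
    """Return the device to use, computed once and cached.

    The optional DEVICE environment variable (a hypothetical convention,
    e.g. DEVICE=cuda:2) overrides the automatic choice.
    """
    requested = os.environ.get("DEVICE")
    if requested:
        return torch.device(requested)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Usage from any script: from device_utils import get_device
if __name__ == "__main__":
    device = get_device()
    x = torch.zeros(2, 3, device=device)
    print(x.device)
```

Because of the cache, every caller gets the same `torch.device` object. The choice can still be overridden without editing code.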

Upvotes: 0

Views: 1160

Answers (2)

NickBraunagel

Reputation: 1599

I'm not 100% sure this will apply to your case, but you can also call .to(device) after the model has been initialized:

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class myModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(100, 100)


model = myModel().to(device)

print(next(model.parameters()).device) # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"

It's also fine to include the device variable as a parameter within your model class. Here is another implementation option:

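class myModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(100, 100)  # the property below needs at least one parameter

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters live on one device
        return next(self.parameters()).device

model = myModel().to(device)
print(model.device) # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"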
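The property also covers the case in the question: forward() can allocate new tensors on `self.device` without a device argument. The sketch below uses a plain `nn.Linear` as a hypothetical stand-in for `GCNConv` and a zero-filled skip tensor as in the question. `SkipNet` is a hypothetical name.

```python
import torch
import torch.nn as nn


class SkipNet(nn.Module):
    """Hypothetical example: the model finds its device from its own parameters."""

    def __init__(self, in_feats: int, hid_feats: int):
        super().__init__()
        self.lin = nn.Linear(in_feats, hid_feats)  # stand-in for GCNConv

    @property
    def device(self) -> torch.device:
        # Read on every access, so it follows later .to()/.cuda() calls.
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.lin(x)
        # Allocate the skip tensor directly on the model's device.
        skip_conn = torch.zeros(x.size(0), x.size(1), device=self.device)
        return torch.cat((h, skip_conn), dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkipNet(4, 6).to(device)
out = model(torch.randn(5, 4, device=device))
print(out.shape)  # torch.Size([5, 10])
```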

Upvotes: 1

Dani Cores

Reputation: 301

You can add a new attribute to MyNet to store the device info and use it in the skip_conn initialization.

import copy

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, device):  # <--
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...
        self.device = device  # <--
        self.to(self.device)  # <-- call last, so every layer defined above is moved

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)

Notice that in this example MyNet is responsible for all the device logic, including the .to(device) call. This way, all model-related device management is encapsulated in the model class itself.
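Storing `self.device` has one caveat. `nn.Module.to()` moves parameters and buffers but does not update a plain attribute, so the stored value goes stale if the model is moved again after construction. An alternative that needs no stored device is to take it from the input tensor with `Tensor.new_zeros`, which allocates on the same device and with the same dtype as the tensor it is called on. The sketch below uses `nn.Linear` as a hypothetical stand-in for `GCNConv`. `InputDeviceNet` is a hypothetical name.

```python
import torch
import torch.nn as nn


class InputDeviceNet(nn.Module):
    """Hypothetical example: no device argument, attribute, or global needed."""

    def __init__(self, in_feats: int, hid_feats: int):
        super().__init__()
        self.lin = nn.Linear(in_feats, hid_feats)  # stand-in for GCNConv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = x.float()
        h = self.lin(x1)
        # new_zeros inherits device and dtype from x1, so it always matches the input.
        skip_conn = x1.new_zeros(x1.size(0), x1.size(1))
        return torch.cat((h, skip_conn), dim=1)


model = InputDeviceNet(3, 5)
out = model(torch.ones(2, 3))
print(out.shape)  # torch.Size([2, 8])
print(out[:, 5:].abs().sum().item())  # 0.0: the skip part is all zeros
```

With this pattern, the only device-aware code left is the `.to(device)` calls in train.py.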

Upvotes: 1
