Reputation: 1311
Currently, I have separated train.py and model.py for my deep learning project. The datasets are sent to the CUDA device inside the epoch for loop, like below.
train.py
...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)
However, my model (in model.py) also needs access to the device variable, for a skip-connection-like method that makes a new copy of the hidden unit (for the residual connection):
model.py
import copy
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <-- needs the device variable
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)
In this case, I am currently passing device as a parameter; however, I believe this is not best practice. How should I handle the device variable? (as a parameter, a global variable?)
Upvotes: 0
Views: 1160
Reputation: 1599
I'm not 100% sure this will apply to your case, but you can also use .to(device) after the model has been initialized:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class myModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(100, 100)

model = myModel().to(device)
print(next(model.parameters()).device)  # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"
It's also fine to include the device variable as a parameter within your model class. Here is another implementation option:
class myModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(100, 100)  # at least one registered parameter, so next(self.parameters()) works

    @property
    def device(self):
        return next(self.parameters()).device

model = myModel()
print(model.device)  # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"
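If you go with the property version, your forward pass can read the device straight off the module instead of a global variable. A minimal sketch of how that might look for the MyNet from your question (same GCNConv setup assumed; only the skip_conn line really changes):
import copy
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv  # assumed, as in your snippet

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)

    @property
    def device(self):
        # Device of the module's parameters (valid as long as the model has at least one parameter)
        return next(self.parameters()).device

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        # Allocate the skip tensor on whatever device the model currently lives on
        skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)
        return x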
Upvotes: 1
Reputation: 301
You can add a new attribute to MyNet to store the device info and use it in the skip_conn initialization.
class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, device):  # <--
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        self.device = device  # <--
        self.to(self.device)  # <--
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)
Notice that in this example, MyNet is responsible for all the device logic, including the .to(device) call. This way, all model-related device management is encapsulated in the model class itself.
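For completeness, here is roughly what the train.py side would look like with this version; in_feats, hid_feats and out_feats are just placeholders for whatever arguments your MyNet actually takes:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(in_feats, hid_feats, out_feats, device)  # no separate model.to(device) call needed
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)  # input tensors still need to be moved explicitly
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)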
Upvotes: 1