import torch
import models
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
model = models.__dict__['resnet18']
model = torch.nn.DataParallel(model,device_ids = [0]) #PROBLEM CAUSING LINE
model.to('cuda:0')
To run this code you need to clone this repository: https://github.com/SoftwareGift/FeatherNets_Face-Anti-spoofing-Attack-Detection-Challenge-CVPR2019.git
Please run this piece of code from the root folder of the cloned repository.
I am getting the following error: AttributeError: 'function' object has no attribute 'cuda'
I have also tried using a torch.device object for the same call, and it produces the same error.
Please ask for any other details that are required. PyTorch newbie here.
Python 3.7, PyTorch 1.3.1
Replace
model = torch.nn.DataParallel(model,device_ids = [0])
with
model = torch.nn.DataParallel(model(), device_ids=[0])
(notice the () after model inside DataParallel). The difference is simple: the models module contains classes/functions that create models, not instances of models. If you trace the imports, you'll find that models.__dict__['resnet18'] resolves to a model-building function. Since DataParallel wraps a model instance, not the class or function that builds it, passing the builder itself does not work. The () calls this model-building function (or class constructor) to create an instance of the model.
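To make the builder-versus-instance distinction concrete without needing PyTorch or the cloned repository, here is a minimal pure-Python sketch. Everything in it is hypothetical: `fake_models` stands in for the repository's models package, `resnet18` for its builder function, and `FakeNet` for an nn.Module subclass. The name filter is copied from the question.

```python
import types

class FakeNet:
    """Hypothetical stand-in for an nn.Module subclass."""
    def cuda(self, device=None):
        # A real nn.Module would move its parameters to the GPU here.
        return self

def resnet18():
    """Hypothetical builder function, analogous to models.resnet18."""
    return FakeNet()

# Build a stand-in for the `models` package.
fake_models = types.ModuleType("fake_models")
fake_models.resnet18 = resnet18
fake_models.FakeNet = FakeNet  # capitalized, so the filter below skips it

# Same filter as in the question: lowercase, non-dunder, callable names.
model_names = sorted(name for name in fake_models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(fake_models.__dict__[name]))

builder = fake_models.__dict__["resnet18"]  # the function itself, not a model
model = builder()                           # calling it builds an instance

print(model_names)                               # ['resnet18']
print(isinstance(builder, types.FunctionType))   # True
print(hasattr(builder, "cuda"), hasattr(model, "cuda"))  # False True
```

The lookup by name only selects the builder; the model exists only after the call.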
A much simpler example of this would be the following:
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = nn.DataParallel(MyNet)    # this is what you're doing: passes the class itself
model = nn.DataParallel(MyNet())  # this is what you should be doing: passes an instance
Your error message complains that a function (since model without the () is of type function) has no attribute cuda, which is a method of nn.Module instances.
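The error message itself can be reproduced in plain Python. This sketch assumes, as the traceback suggests, that DataParallel calls .cuda() on the object it wraps when given a single device; `fake_data_parallel` and `resnet18` are hypothetical stand-ins, not PyTorch APIs.

```python
def resnet18():
    """Hypothetical builder, standing in for models.resnet18."""
    class Net:
        def cuda(self, device=None):
            # A real nn.Module would move its parameters to the given GPU.
            return self
    return Net()

def fake_data_parallel(module, device_ids):
    """Hypothetical wrapper: like the assumed DataParallel behaviour,
    it calls .cuda() on whatever object it is given."""
    module.cuda(device_ids[0])
    return module

message = None
try:
    fake_data_parallel(resnet18, device_ids=[0])  # passes the function itself
except AttributeError as err:
    message = str(err)

wrapped = fake_data_parallel(resnet18(), device_ids=[0])  # passes an instance
print(message)  # 'function' object has no attribute 'cuda'
```

Passing the function fails exactly as in the question, while passing the result of calling it succeeds.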