Tanmay Bhatnagar

Reputation: 2470

Pytorch : AttributeError: 'function' object has no attribute 'cuda'

import torch
import models
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
model = models.__dict__['resnet18']
model = torch.nn.DataParallel(model,device_ids = [0])  #PROBLEM CAUSING LINE
model.to('cuda:0')

To run this code, you need to clone this repository: https://github.com/SoftwareGift/FeatherNets_Face-Anti-spoofing-Attack-Detection-Challenge-CVPR2019.git

Please run this piece of code inside the root folder of the cloned directory.

I am getting the following error: AttributeError: 'function' object has no attribute 'cuda'. I have also tried using a torch.device object in the same call, and it results in the same error. Please ask for any other details that are required. PyTorch newbie here; Python 3.7, PyTorch 1.3.1.

Upvotes: 2

Views: 9002

Answers (1)

Jatentaki

Reputation: 13113

Replace

model = torch.nn.DataParallel(model,device_ids = [0])

with

model = torch.nn.DataParallel(model(), device_ids=[0])

(Notice the () after model inside DataParallel.) The difference is simple: the models module contains classes and functions that create models, not instances of models. If you trace the imports, you'll find that models.__dict__['resnet18'] resolves to a model-building function. DataParallel wraps a module instance, not a class or a builder function, so passing the function itself fails. Adding () calls the builder function (or class constructor) and creates an instance of the model.
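The following torch-free sketch mimics that registry pattern so the failure can be reproduced without the cloned repository. FakeModel, resnet18 and the SimpleNamespace standing in for models are hypothetical stand-ins, not the repository's actual code; only the name filter is taken from the question.

```python
from types import SimpleNamespace


class FakeModel:
    """Hypothetical stand-in for an nn.Module instance (only .cuda is modeled)."""

    def cuda(self, device=None):
        return self


def resnet18():
    """Hypothetical builder function, analogous to a model factory."""
    return FakeModel()


# Hypothetical stand-in for the repository's `models` package namespace.
models = SimpleNamespace(resnet18=resnet18, FakeModel=FakeModel, default_size=224)

# Same filter as in the question: lowercase, non-dunder, callable entries.
model_names = sorted(
    name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name])
)

factory = models.__dict__["resnet18"]  # the builder function itself
instance = factory()                   # calling it builds a model object

try:
    factory.cuda(0)                    # roughly what happens when DataParallel moves the module to the GPU
except AttributeError as err:
    message = str(err)

print(model_names)                # ['resnet18']
print(type(factory).__name__)     # function
print(message)                    # 'function' object has no attribute 'cuda'
print(hasattr(instance, "cuda"))  # True
```

FakeModel is excluded by islower() and default_size by callable(), so only the builder survives the filter, and it is still a plain function until it is called.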

A much simpler example of the same mistake:

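import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# model = nn.DataParallel(MyNet)  # this is what you're doing: wrapping the class itself
model = nn.DataParallel(MyNet())  # this is what you should be doing: wrapping an instance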

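Your error message complains that a function (model without the () is a function object) has no attribute cuda, which is a method of nn.Module instances. The error comes from the DataParallel line because, with a single entry in device_ids, the DataParallel constructor in PyTorch 1.3.1 moves the wrapped module to that GPU by calling .cuda() on it.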
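If the same wrapping code may receive either a ready-made model or a builder, a small guard can normalize it before calling DataParallel. This is a sketch assuming torch is installed; wrap_for_data_parallel is a hypothetical helper, not a PyTorch API, and it assumes the builder takes no arguments (a real builder such as the repository's resnet18 may need some). Without CUDA, DataParallel simply forwards to the wrapped module, so the example also runs on CPU.

```python
import torch
import torch.nn as nn


class MyNet(nn.Module):
    """The small example network from the answer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def wrap_for_data_parallel(model_or_factory):
    """Hypothetical helper: accept an nn.Module instance or a zero-argument
    builder (class or function) and return it wrapped in nn.DataParallel."""
    if isinstance(model_or_factory, nn.Module):
        model = model_or_factory          # already an instance: use as-is
    else:
        model = model_or_factory()        # class or builder function: call it
    if not isinstance(model, nn.Module):
        raise TypeError(f"expected an nn.Module, got {type(model).__name__}")
    if torch.cuda.is_available():
        # With default device_ids, DataParallel expects parameters on cuda:0.
        model = model.to("cuda:0")
    return nn.DataParallel(model)


wrapped = wrap_for_data_parallel(MyNet)   # the class is called for you
out = wrapped(torch.randn(2, 4))
print(tuple(out.shape))                   # (2, 4)
```

Raising TypeError for anything that still is not an nn.Module after one call keeps the failure close to its cause, instead of surfacing later as an AttributeError inside DataParallel.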

Upvotes: 2
