user2517984

Reputation: 115

torch.nn.DataParallel and to(device) does not support nested modules

I have a torch.nn.Module subclass defined in the following way:

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sub_module_a = ....  # nn.Module
        self.sub_module_b_dict = {
            'B': ....  # nn.Module
        }

However, after I call torch.nn.DataParallel(MyModule) and MyModule.to(device), only sub_module_a is put on CUDA. The 'B' module inside self.sub_module_b_dict is still on the CPU.

It looks like DataParallel and to(device) only handle attributes defined directly on a torch.nn.Module. Modules nested inside a custom structure (in this case, a plain dictionary) seem to be ignored.
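Here is a minimal sketch of what I am seeing (nn.Linear is just a stand-in for my real submodules, and I am assuming a CUDA device is available):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sub_module_a = nn.Linear(4, 4)                  # registered as a submodule
        self.sub_module_b_dict = {'B': nn.Linear(4, 4)}      # plain dict: not registered

model = MyModule()

# Only sub_module_a is visible to nn.Module's machinery; 'B' is missing.
print([name for name, _ in model.named_modules()])           # ['', 'sub_module_a']

model.to('cuda')
print(model.sub_module_a.weight.device)                      # cuda:0
print(model.sub_module_b_dict['B'].weight.device)            # still cpu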

Am I missing some caveats here?

Upvotes: 2

Views: 1232

Answers (1)

Shai

Reputation: 114866

You MUST use proper nn containers for nn.Module's methods (such as .to(device), .cuda(), and parameters()) to act recursively on submodules.

In your case, the 'B' module is stored in a plain Python dictionary. Replace it with nn.ModuleDict and you should be fine:

self.sub_module_b_dict = nn.ModuleDict({'B': ...})
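For instance, a minimal sketch of the fixed module (nn.Linear stands in for your actual submodules):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sub_module_a = nn.Linear(4, 4)
        # nn.ModuleDict registers 'B' as a proper submodule, so .to(device),
        # .cuda(), parameters(), etc. now reach it recursively.
        self.sub_module_b_dict = nn.ModuleDict({'B': nn.Linear(4, 4)})

model = MyModule().to('cuda')
print(model.sub_module_b_dict['B'].weight.device)  # cuda:0
model = nn.DataParallel(model)  # all submodules are now replicated across GPUs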


Upvotes: 2
