Reputation: 115
I have a torch.nn.Module class defined in the following way:
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sub_module_a = ...  # an nn.Module
        self.sub_module_b_dict = {
            'B': ...  # an nn.Module
        }
However, after I call torch.nn.DataParallel on the model and then model.to(device), only sub_module_a is put on CUDA. The 'B' inside self.sub_module_b_dict is still on the CPU.
It looks like DataParallel and to(device) only act on the top-level attributes of a torch.nn.Module. Modules nested inside a custom structure (in this case, a plain dictionary) seem to be ignored.
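For example, here is a minimal reproduction of what I'm seeing (using hypothetical nn.Linear layers as stand-ins for my actual submodules):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_module_a = nn.Linear(4, 4)               # registered as a submodule
        self.sub_module_b_dict = {'B': nn.Linear(4, 4)}   # plain dict: not registered

m = MyModule()
print(list(m.children()))                        # only shows sub_module_a
m.to('cuda')
print(m.sub_module_a.weight.device)              # cuda:0
print(m.sub_module_b_dict['B'].weight.device)    # cpu -- not moved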
Am I missing some caveats here?
Upvotes: 2
Views: 1232
Reputation: 114866
You MUST use proper nn containers for nn.Module's methods (such as .to(device), .cuda(), .parameters() and .state_dict()) to act recursively on submodules.
In your case, the 'B' module is stored in a plain Python dictionary, which nn.Module does not track. Replace it with nn.ModuleDict and you should be fine:
self.sub_module_b_dict = nn.ModuleDict({'B': ...})
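For completeness, a minimal sketch of the fixed module (again with hypothetical nn.Linear layers standing in for the elided submodules), showing that .to(device) now reaches the nested module:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_module_a = nn.Linear(8, 8)      # placeholder submodule
        self.sub_module_b_dict = nn.ModuleDict({
            'B': nn.Linear(8, 8)                 # now registered as a submodule
        })

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModule().to(device)
# Both submodules' parameters now live on `device`:
print(model.sub_module_a.weight.device)
print(model.sub_module_b_dict['B'].weight.device)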
See related threads:
a, b, c, d, e, to name just a few...
Upvotes: 2