Reputation: 949
Basically, as this thread discusses, you cannot use a Python list to wrap your sub-modules (for example, your layers); otherwise, PyTorch will not update the parameters of the sub-modules inside the list. Instead, you should use nn.ModuleList
to wrap your sub-modules so that their parameters are registered and updated. However, I have also seen code like the following, where the author uses a Python list to collect the loss terms and then calls loss.backward()
to do the update (in the REINFORCE algorithm of RL). Here is the code:
policy_loss = []
# one REINFORCE loss term per selected action: -log_prob * (reward - baseline)
for log_prob in self.controller.log_probability_slected_action_list:
    policy_loss.append(-log_prob * (average_reward - b))
self.optimizer.zero_grad()
# reduce the list of loss terms to a single scalar before calling backward()
final_policy_loss = torch.cat(policy_loss).sum() * gamma
final_policy_loss.backward()
self.optimizer.step()
Why does using a list in this fashion work for updating the parameters of the modules, while the first case does not? I am very confused now. If I change policy_loss = [] to policy_loss = nn.ModuleList([])
in the code above, it throws an exception saying that a float tensor is not a sub-module.
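For concreteness, here is a minimal sketch of the first case (the PlainListNet and ModuleListNet classes are made up for illustration), showing that sub-modules hidden in a plain Python list never get registered as parameters:

import torch
import torch.nn as nn

class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # sub-modules in a plain Python list are NOT registered with the Module
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 1)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers the sub-modules and their parameters
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(len(list(PlainListNet().parameters())))   # 0 -> an optimizer would see nothing
print(len(list(ModuleListNet().parameters())))  # 4 -> weights and biases of both layers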
Upvotes: 0
Views: 465
Reputation: 3713
You are misunderstanding what Modules are. A Module stores parameters and defines an implementation of the forward pass.
You're allowed to perform arbitrary computation with tensors and parameters, producing other new tensors; Modules need not be aware of those tensors. You're also allowed to store tensors in Python lists. When calling backward, it needs to be called on a scalar tensor, hence the sum of the concatenation. These tensors are losses, not parameters, so they should neither be attributes of a Module nor wrapped in a ModuleList.
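To make that concrete, here is a minimal sketch (the toy policy network and the constant advantage are made up for illustration): the loss tensors collected in a plain Python list remain connected to the autograd graph, so reducing them to a scalar and calling backward() flows gradients back into the module's registered parameters:

import torch
import torch.nn as nn

policy = nn.Linear(4, 2)   # toy policy network (illustrative)
optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)

states = torch.randn(3, 4)                                   # fake batch of states
log_probs = torch.log_softmax(policy(states), dim=1)[:, 0]   # pretend action 0 was taken
advantage = 1.5                                              # made-up constant advantage

# ordinary loss tensors collected in a plain Python list: this is fine,
# because each one still carries a grad_fn linking it to the policy's parameters
policy_loss = [-lp * advantage for lp in log_probs]

optimizer.zero_grad()
final_loss = torch.stack(policy_loss).sum()   # reduce to a scalar for backward()
final_loss.backward()
optimizer.step()                              # the Linear's registered parameters get updated

The list itself is inert bookkeeping; what matters is that each tensor in it has a grad_fn. By contrast, nn.ModuleList only accepts Module instances, which is exactly the exception you saw.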
Upvotes: 2