Infintyyy

Reputation: 949

Understanding when to use a Python list in PyTorch

Basically, as this thread discusses here, you cannot use a Python list to wrap your sub-modules (for example, your layers); otherwise, PyTorch will not update the parameters of the sub-modules inside the list. Instead, you should wrap your sub-modules in nn.ModuleList so that their parameters get updated. However, I have also seen code like the following, where the author uses a Python list to collect the losses and then calls loss.backward() to do the update (in the REINFORCE algorithm of RL). Here is the code:

    policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(-log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = torch.cat(policy_loss).sum() * gamma
    final_policy_loss.backward()
    self.optimizer.step()

Why does using a list in this way work for updating the parameters of the modules, while the first case does not? I am very confused. If I change policy_loss = [] to policy_loss = nn.ModuleList([]) in the code above, it throws an exception saying that a float tensor is not a sub-module.
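For reference, here is a minimal sketch of the registration difference described above (the class and layer sizes are made up for illustration, not taken from the linked thread):

```python
import torch.nn as nn

class BadNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as sub-modules,
        # so parameters() never sees them and the optimizer cannot update them.
        self.layers = [nn.Linear(4, 4) for _ in range(2)]

class GoodNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: the layers ARE registered as sub-modules,
        # so their weights and biases appear in parameters().
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

print(len(list(BadNet().parameters())))   # 0
print(len(list(GoodNet().parameters())))  # 4 (2 weights + 2 biases)
```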

Upvotes: 0

Views: 465

Answers (1)

Manux

Reputation: 3713

You are misunderstanding what Modules are. A Module stores parameters and defines an implementation of the forward pass.

You are allowed to perform arbitrary computation with tensors and parameters, producing new tensors; Modules do not need to be aware of those tensors. You are also allowed to store tensors in plain Python lists. backward() must be called on a scalar tensor, which is why the code sums the concatenation. These tensors are losses, not parameters, so they should neither be attributes of a Module nor wrapped in a ModuleList.
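To make this concrete, here is a self-contained sketch of the same pattern: per-step losses collected in a plain Python list, summed into a scalar, and back-propagated. The tiny linear "policy" is a stand-in I made up, not the asker's controller, and torch.stack is used instead of torch.cat because the per-step losses here happen to be zero-dimensional:

```python
import torch
import torch.nn as nn

policy = nn.Linear(4, 2)  # hypothetical one-layer "policy"
optimizer = torch.optim.SGD(policy.parameters(), lr=0.01)

# Collect per-step losses in a plain Python list. Each entry is a tensor
# that is still connected to policy's parameters through the autograd
# graph, so a list is perfectly fine here.
losses = []
for _ in range(3):
    logits = policy(torch.randn(1, 4))
    log_prob = torch.log_softmax(logits, dim=-1)[0, 0]
    losses.append(-log_prob)

optimizer.zero_grad()
final_loss = torch.stack(losses).sum()  # backward() needs a scalar
final_loss.backward()
optimizer.step()

# Gradients reached the parameters even though the losses lived in a list.
print(policy.weight.grad is not None)  # True
```

The list here is just a container for intermediate tensors; the gradient flow comes from the autograd graph those tensors carry, not from how they are stored.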

Upvotes: 2
