Blade

Combining ParameterLists in PyTorch

I am trying to combine two ParameterLists in PyTorch. I've implemented the following snippet:

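import torch
from torch import nn

# sub_list_1 and sub_list_2 are existing iterables of nn.Parameter objects
plist = nn.ParameterList()  # avoid shadowing the built-in `list`
for p in sub_list_1:
    plist.append(p)
for p in sub_list_2:
    plist.append(p)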

Is there a function that takes care of this without needing to loop over each list?

Answers (2)

Blade

In addition to @jodag's answer, another solution is to unpack both lists and construct a new ParameterList from their elements:

param_list = nn.ParameterList([*sub_list_1, *sub_list_2])
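A minimal, self-contained sketch of the unpacking approach. The contents and shapes of sub_list_1 and sub_list_2 are made-up placeholders for whatever the question's lists hold. Unpacking accepts any iterable of parameters (a plain list or an existing ParameterList) and builds a new ParameterList, leaving both inputs unchanged:

```python
import torch
from torch import nn

# Hypothetical inputs standing in for sub_list_1 / sub_list_2 from the question
sub_list_1 = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(2)])
sub_list_2 = [nn.Parameter(torch.zeros(3))]

# Unpack both iterables into one Python list and wrap it in a new ParameterList
param_list = nn.ParameterList([*sub_list_1, *sub_list_2])

print(len(param_list))                 # 3
# The entries are the original Parameter objects, not copies
print(param_list[0] is sub_list_1[0])  # True
print(len(sub_list_1))                 # 2: the inputs are not modified
```

Because the entries are shared rather than copied, updating a parameter through param_list also updates it in the original sublist.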

jodag

You can use nn.ParameterList.extend, which works like Python's built-in list.extend:

plist = nn.ParameterList()
plist.extend(sub_list_1)
plist.extend(sub_list_2)
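A sketch of extend inside a module, assuming the sublists are plain Python lists of nn.Parameter (the shapes and the Combined class are hypothetical, for illustration only). Since the ParameterList is assigned as a module attribute, every extended parameter is registered and shows up in model.parameters():

```python
import torch
from torch import nn

# Hypothetical inputs standing in for sub_list_1 / sub_list_2
sub_list_1 = [nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(4))]
sub_list_2 = [nn.Parameter(torch.randn(4, 2))]


class Combined(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.plist = nn.ParameterList()
        # extend() appends every element of the iterable, in order
        self.plist.extend(sub_list_1)
        self.plist.extend(sub_list_2)


model = Combined()
n_params = sum(p.numel() for p in model.parameters())
print(len(model.plist), n_params)  # 3 28  (16 + 4 + 8 elements)
```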

Alternatively, you can use +=, which is an alias for extend:

plist = nn.ParameterList()
plist += sub_list_1
plist += sub_list_2
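A short check of the += form, with hypothetical placeholder sublists. Like list +=, it modifies the ParameterList in place (the name still refers to the same object) and accepts another ParameterList on the right-hand side, since that is iterable too:

```python
import torch
from torch import nn

# Hypothetical inputs standing in for sub_list_1 / sub_list_2
sub_list_1 = [nn.Parameter(torch.ones(2)), nn.Parameter(torch.ones(3))]
sub_list_2 = nn.ParameterList([nn.Parameter(torch.ones(5))])

plist = nn.ParameterList()
original_id = id(plist)
plist += sub_list_1  # equivalent to plist.extend(sub_list_1)
plist += sub_list_2  # a ParameterList works as the right-hand side too

print(len(plist))                  # 3
print(id(plist) == original_id)    # True: extended in place, not rebound
print([p.numel() for p in plist])  # [2, 3, 5]
```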
