skytree

Reputation: 1100

How to change the device of only some of a module's parameters in PyTorch?

I defined a net in which two parameters are on the CPU, and I tried to move those two parameters to the GPU. However, when I print the device afterwards, I find that they were not moved. How do I change the device of a model parameter?

for p in net.parameters():
    if p.device == torch.device('cpu'):
        p = p.to('cuda')


for p in net.parameters():
    if p.device == torch.device('cpu'):
        print(p.device)

Output:

cpu
cpu

Upvotes: 1

Views: 2232

Answers (1)

Berriel

Reputation: 13601

You're dealing with parameters, not a Module. Unlike Module.to(...), which works in place, calling .to(...) on a parameter returns a new tensor (a copy, when the device differs), so the result has to be assigned back. Assigning it to the loop variable p only rebinds that local name and leaves the parameter held by net untouched. To update the parameter itself, assign to its .data:

for p in net.parameters():
    if p.device == torch.device('cpu'):
        # Write the GPU copy into the existing parameter instead of rebinding p.
        p.data = p.to('cuda')
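To see the difference between rebinding p and assigning to p.data without needing a GPU, here is a minimal sketch that uses a dtype conversion as a stand-in for the device move (an assumption made only so it runs on a CPU-only machine; .to(...) returns a new tensor in both cases). The nn.Linear layer is a hypothetical stand-in for the net in the question.

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 2)  # hypothetical stand-in for the question's net

# Attempt 1: rebind the loop variable. Only the local name p changes;
# the parameters registered in net keep their original dtype.
for p in net.parameters():
    p = p.to(torch.float64)
dtypes_after_rebind = [p.dtype for p in net.parameters()]

# Attempt 2: assign to .data. The same Parameter objects stay registered
# in net, but they now hold the converted tensors.
ids_before = [id(p) for p in net.parameters()]
for p in net.parameters():
    p.data = p.data.to(torch.float64)
dtypes_after_data = [p.dtype for p in net.parameters()]
ids_after = [id(p) for p in net.parameters()]

print(dtypes_after_rebind)       # dtypes seen after rebinding
print(dtypes_after_data)         # dtypes seen after assigning to .data
print(ids_before == ids_after)   # whether the Parameter objects are the same
```

The same reasoning should carry over when the argument to .to(...) is a device such as 'cuda' rather than a dtype.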

Note that if any of these parameters has a .grad, that gradient will not be moved to the GPU by the loop above. Take a look here at how parameters are usually moved. As you'll see, you have to do the same for the gradients:

for p in net.parameters():
    if p.device == torch.device('cpu'):
        # Move the parameter values first.
        p.data = p.to('cuda')
        # .grad is a separate tensor, so it has to be moved as well.
        if p.grad is not None:
            p.grad.data = p.grad.to('cuda')
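The loop above could be wrapped in a small helper, sketched below. The name move_cpu_params, the nn.Linear stand-in for net, and the fallback to the CPU when no GPU is available are all assumptions made so the example runs anywhere; they are not part of the original answer.

```python
import torch
import torch.nn as nn


def move_cpu_params(module: nn.Module, device: torch.device) -> int:
    """Move every CPU-resident parameter (and its gradient, if any) to
    `device`. Returns the number of parameters that were visited on CPU."""
    moved = 0
    for p in module.parameters():
        if p.device.type == 'cpu':
            # Overwrite the storage held by the existing Parameter object.
            p.data = p.data.to(device)
            # The gradient is a separate tensor and must follow the data.
            if p.grad is not None:
                p.grad.data = p.grad.data.to(device)
            moved += 1
    return moved


# Assumption: use the GPU when present, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = nn.Linear(4, 2)                      # hypothetical stand-in for net
net(torch.randn(1, 4)).sum().backward()    # populate .grad on each parameter

moved = move_cpu_params(net, device)
param_devices = {p.device.type for p in net.parameters()}
grad_devices = {p.grad.device.type for p in net.parameters()}
print(moved, param_devices, grad_devices)
```

If every parameter should end up on the same device anyway, calling net.to(device) does this for the whole module in one step; the helper is only worth having when the move needs to be selective.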

Upvotes: 3
