Caleb Wan

Reputation: 23

Manually adjusting parameters of a torch.nn.Module

Suppose I had a simple neural network defined by:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2,2)
        self.fc2 = nn.Linear(2,2)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
    
net = Net()

Doing the following:

for param in net.parameters():
    print(param.data)

results in something like:

tensor([[-0.0776,  0.2409],
        [ 0.3478, -0.6820]])
tensor([-0.6311,  0.2323])
tensor([[-0.5466,  0.0341],
        [ 0.5822,  0.7005]])
tensor([-0.5624,  0.3278])

Let's say I have a tensor([[0,0],[0,1]]) and want to replace the first param.data with this custom tensor.

Is this possible and if so, how can I do this?

Upvotes: 0

Views: 1457

Answers (1)

Theodor Peifer

Reputation: 3506

This should be possible, although I don't know why you would want to do it, haha. Anyway, this should work:


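# Use floats so the dtype matches the layer's existing float32 weights
replace_with = torch.tensor([[0., 0.], [0., 1.]])

for parameter in net.parameters():
    # parameters() yields fc1.weight first, so overwrite it and stop
    parameter.data = replace_with
    break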

Now the first parameter (the weight of fc1) should be your custom tensor. I hope that solves your issue :)
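
As an alternative to assigning to .data, here is a minimal sketch that copies the values in place under torch.no_grad(). It assumes the Net class from the question and picks the weight by name (net.fc1.weight) instead of relying on iteration order. The variable name custom is only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


net = Net()

# Hypothetical replacement values; float dtype to match the layer's weights.
custom = torch.tensor([[0.0, 0.0], [0.0, 1.0]])

# Copy the values into the existing parameter without tracking the
# operation in autograd. The parameter object itself stays the same.
with torch.no_grad():
    net.fc1.weight.copy_(custom)

print(net.fc1.weight)
```

Because copy_ writes into the existing parameter, it stays an nn.Parameter with requires_grad=True, and an optimizer that already holds a reference to it keeps working. copy_ also raises an error if the shapes cannot be broadcast, which catches some mistakes early.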

Upvotes: 2
