saquibmazhar

Optimising model.parameters() and a custom learnable parameter together with torch.optim gives a non-leaf tensor error

Framework: PyTorch

I am trying to optimise a custom nn.Parameter (a temperature used in the softmax calculation) together with the model parameters, using a single Adam optimiser during training. However, doing so gives the following error:

ValueError: can't optimize a non-leaf Tensor

Here is my custom loss function:

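import torch

# `device` is assumed to be defined elsewhere in the training script.

class CrossEntropyLoss2d(torch.nn.Module):
    def __init__(self, weight=None):
        super().__init__()
        # Learnable temperature; nn.Parameter already sets requires_grad=True.
        self.temperature = torch.nn.Parameter(torch.ones(1, device=device))
        self.loss = torch.nn.NLLLoss(weight)
        self.loss.to(device)

    def forward(self, outputs, targets):
        # Divide the logits by the temperature, then apply log-softmax over the class dim.
        T_logits = self.temp_scale(outputs)
        return self.loss(torch.nn.functional.log_softmax(T_logits, dim=1), targets)

    def temp_scale(self, logits):
        # For 4-D logits (N, C, H, W): expand to (C, H, W), which then broadcasts over N.
        temp = self.temperature.unsqueeze(1).expand(logits.size(1), logits.size(2), logits.size(3))
        return logits / temp

# ...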
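As a quick check of the temperature scaling in temp_scale, the sketch below assumes 4-D segmentation logits of shape (N, C, H, W) with made-up sizes, and uses a standalone Parameter as a stand-in for criterion.temperature. It shows that dividing by the shape-(1,) temperature directly gives the same result as the unsqueeze/expand version (plain broadcasting), and that the temperature receives a gradient:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration: batch N=2, C=3 classes, 4x4 spatial map.
N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)
targets = torch.randint(0, C, (N, H, W))

# Stand-in for criterion.temperature.
temperature = torch.nn.Parameter(torch.ones(1))

# The question's approach: expand to (C, H, W), which then broadcasts over N.
expanded = temperature.unsqueeze(1).expand(C, H, W)

# Dividing by the shape-(1,) Parameter directly broadcasts over every dimension.
scaled = logits / temperature
print(torch.equal(logits / expanded, scaled))  # True

loss = F.nll_loss(F.log_softmax(scaled, dim=1), targets)
loss.backward()
print(temperature.grad is not None)  # True: the temperature receives a gradient
```

So the temperature itself is set up correctly; the problem lies in how it is handed to the optimiser.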

Here is the relevant part of the training code:

from torch.optim import Adam

criterion = CrossEntropyLoss2d(weight)
params = list(model.parameters()) + list(criterion.temperature)
optimizer = Adam(params, 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=1e-4)

Error:

  File "train_my_net_city.py", line 270, in train
    optimizer = Adam(params, 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=1e-4)
  File "/home/saquib/anaconda3/lib/python3.8/site-packages/torch/optim/adam.py", line 48, in __init__
    super(Adam, self).__init__(params, defaults)
  File "/home/saquib/anaconda3/lib/python3.8/site-packages/torch/optim/optimizer.py", line 54, in __init__
    self.add_param_group(param_group)
  File "/home/saquib/anaconda3/lib/python3.8/site-packages/torch/optim/optimizer.py", line 257, in add_param_group
    raise ValueError("can't optimize a non-leaf Tensor")
ValueError: can't optimize a non-leaf Tensor

However, checking whether the parameter is a leaf returns True:

print(criterion.temperature.is_leaf)
True

The error is caused by criterion.temperature, not by model.parameters().
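To see where the non-leaf tensor comes from, the following minimal sketch compares a standalone Parameter (a stand-in for criterion.temperature) with the elements that list() produces from it. The learning rate is only illustrative:

```python
import torch

# Stand-in for criterion.temperature from the question.
temperature = torch.nn.Parameter(torch.ones(1))
print(temperature.is_leaf)  # True: the Parameter itself is a leaf

# list() iterates over dim 0 and yields each element as a new tensor
# created by an indexing op, so each element carries a grad_fn.
elements = list(temperature)
print(elements[0].is_leaf)  # False
print(elements[0].grad_fn is not None)  # True

# Passing those elements to an optimizer reproduces the question's error.
error_message = None
try:
    torch.optim.Adam(elements, lr=5e-4)
except ValueError as err:
    error_message = str(err)
print(error_message)  # can't optimize a non-leaf Tensor

# Wrapping the Parameter in a list literal keeps the leaf tensor itself.
optimizer = torch.optim.Adam([temperature], lr=5e-4)
print(optimizer.param_groups[0]["params"][0] is temperature)  # True
```

This explains why is_leaf on criterion.temperature reports True while the optimiser still complains: the optimiser never sees the Parameter, only the element tensors derived from it.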

Answers (1)

saquibmazhar

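I got it working by appending the Parameter itself to the list instead of calling list() on it. list(criterion.temperature) does not wrap the tensor in a list; it iterates over the tensor's first dimension and returns each element as a new tensor produced by an indexing operation. Those element tensors have a grad_fn, so they are non-leaf tensors even though criterion.temperature itself is a leaf. The fix: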

params = list(model.parameters())
params.append(criterion.temperature)  # append the leaf Parameter itself, not its elements
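Two other ways to build the same optimiser, as a sketch: wrap the Parameter in a list literal, or, since CrossEntropyLoss2d is an nn.Module that registers temperature as a Parameter, pass criterion.parameters() as its own parameter group. The Conv2d model, the TemperatureHolder class, the tensor shapes and the zero weight decay on the temperature are hypothetical choices for illustration, not part of the original code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the question's model and criterion: a 1x1 conv
# producing 4-class logits, and a module that owns a learnable temperature.
model = torch.nn.Conv2d(3, 4, kernel_size=1)


class TemperatureHolder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = torch.nn.Parameter(torch.ones(1))


criterion = TemperatureHolder()

# Option 1: wrap the Parameter in a list instead of calling list() on it.
params = list(model.parameters()) + [criterion.temperature]
print(len(params))  # 3: conv weight, conv bias, temperature

# Option 2: separate parameter groups, here exempting the temperature from
# weight decay (an illustrative choice, not something the question requires).
optimizer = torch.optim.Adam(
    [
        {"params": model.parameters()},
        {"params": criterion.parameters(), "weight_decay": 0.0},
    ],
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4,
)

# One training step: the temperature is updated along with the model weights.
x = torch.randn(2, 3, 8, 8)
targets = torch.randint(0, 4, (2, 8, 8))
before = criterion.temperature.item()

optimizer.zero_grad()
loss = F.cross_entropy(model(x) / criterion.temperature, targets)
loss.backward()
optimizer.step()

after = criterion.temperature.item()
print(before, after)
```

Keeping the temperature in its own parameter group also makes it easy to give it a different learning rate, since per-group options override the optimiser defaults.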
