Reputation: 13
I was having difficulty with my loss getting stuck at a particular value. It would always decrease to a certain value, then stop decreasing. The code regarding the loss was:
criterion = nn.MSELoss()
loss = criterion(y_pred, y_batch.unsqueeze(1))
When I changed it to:
criterion = nn.MSELoss()
loss = criterion(y_pred, target=y_batch)
the issue was fixed.
What was happening before, when the target was not specified by keyword? Does the target need to be specified for every PyTorch loss function? I found nothing in the documentation about specifying the target.
Upvotes: 1
Views: 815
Reputation: 40748
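It looks like target is just the name of the second positional argument, so criterion(y_pred, y_batch) and criterion(y_pred, target=y_batch) are the same call, and no loss function requires you to pass it by keyword. The only real difference between your two versions is the unsqueeze(1) applied to the target in the first one. If that extra dimension made the target's shape differ from y_pred's, MSELoss broadcasts the two tensors against each other (PyTorch emits a UserWarning about the size mismatch when this happens), so every prediction is compared with every target in the batch. The best the model can then do is predict the batch mean, which is why the loss stops decreasing at a fixed value.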
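A minimal sketch of both points, assuming PyTorch and that y_pred and y_batch already had matching shapes of (N, 1). The tensors here are random, illustrative stand-ins, not the asker's data. It shows that the keyword and positional calls are identical, and that an extra unsqueeze(1) silently broadcasts the loss over all N*N prediction/target pairs.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.MSELoss()

# Illustrative batch (assumption): predictions and targets both shaped (N, 1).
N = 4
y_pred = torch.randn(N, 1)
y_batch = torch.randn(N, 1)

# Passing the target positionally or by keyword is the same call.
loss_positional = criterion(y_pred, y_batch)
loss_keyword = criterion(y_pred, target=y_batch)
assert torch.equal(loss_positional, loss_keyword)

# An extra unsqueeze(1) makes the target (N, 1, 1). Broadcasting it against
# (N, 1) yields an (N, N, 1) difference, so every prediction is compared
# with every target. PyTorch warns about the size mismatch.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_broadcast = criterion(y_pred, y_batch.unsqueeze(1))

# The broadcast loss is the mean over all N*N prediction/target pairs.
pairwise = ((y_pred.view(1, N) - y_batch.view(N, 1)) ** 2).mean()
assert torch.allclose(loss_broadcast, pairwise)

# That pairwise loss is minimized by predicting the batch mean everywhere,
# where it equals the (population) variance of the targets: the plateau.
mean_pred = torch.full_like(y_pred, y_batch.mean().item())
plateau = criterion(mean_pred, y_batch.unsqueeze(1))
target_var = ((y_batch - y_batch.mean()) ** 2).mean()
assert torch.allclose(plateau, target_var)

print(loss_keyword.item(), loss_broadcast.item(), plateau.item())
```

Checking y_pred.shape == target.shape before computing the loss (or watching for that UserWarning) catches this kind of bug early.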
Upvotes: 1