Reputation: 174
I am trying to train a PyTorch model. The loss function is:
cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight, reduction='mean')
Code fragment from the training function:
for sents, targets in batch_iter(df_train, batch_size=train_batch_size, shuffle=True, bert=bert_size):
    train_iter += 1
    optimizer.zero_grad()
    batch_size = len(sents)
    pre_softmax = model(sents)
    float_targets = torch.tensor(targets, dtype=torch.float, device=device)
    loss = cn_loss(pre_softmax, float_targets)
    loss.backward()
    optimizer.step()
The data types of both pre_softmax and float_targets are torch.float32. (In the original code the targets were converted to torch.int64 using torch.tensor(targets, dtype=torch.long, device=device). However, since I kept getting the RuntimeError: expected scalar type Float but found Double error, I converted the targets to torch.float32.)
Even though both arguments to cn_loss() are torch.float32, I get the following error when I run the code:
loss = cn_loss(pre_softmax, float_targets)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\loss.py", line 1152, in forward
    label_smoothing=self.label_smoothing)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: expected scalar type Float but found Double
I checked the data types multiple times, and I also tried creating float_targets with torch.FloatTensor(targets), but I still get the same error.
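For reference, the dtype check amounts to something like this (a minimal sketch, using the tensors from the training loop above):

print(pre_softmax.dtype)    # torch.float32
print(float_targets.dtype)  # torch.float32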
Upvotes: 0
Views: 1575
Reputation: 174
As pointed out by @aretor in the comments, the data type of train_label_weight was torch.float64. After converting it to torch.float32 and changing the targets back to torch.long, the code worked perfectly.
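A minimal self-contained sketch of the working setup (the shapes, weight values, and class count here are made up for illustration):

import torch

device = torch.device('cpu')

# The class weights must be float32 to match the model output's dtype;
# weights derived from NumPy arrays silently default to float64.
train_label_weight = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float32, device=device)
cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight, reduction='mean')

# CrossEntropyLoss expects raw logits (it applies log-softmax internally),
# which matches the name pre_softmax.
pre_softmax = torch.randn(4, 3, device=device)

# Class-index targets must be torch.long, not float.
targets = torch.tensor([0, 2, 1, 0], dtype=torch.long, device=device)

loss = cn_loss(pre_softmax, targets)

The misleading part is that the error message seems to point at the input/target pair, while the float64 tensor was actually the weight passed to the loss constructor.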
Upvotes: 3