nr spider

Reputation: 174

RuntimeError: expected scalar type Float but found Double error torch.nn.CrossEntropyLoss Pytorch

I am trying to train a pytorch model. The loss function is:

cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight, reduction='mean')

Code fragment from the training function:

for sents, targets in batch_iter(df_train, batch_size=train_batch_size, shuffle=True, bert=bert_size):
    train_iter += 1
    optimizer.zero_grad()
    batch_size = len(sents)
    pre_softmax = model(sents)
    float_targets = torch.tensor(targets, dtype=torch.float, device=device)

    loss = cn_loss(pre_softmax, float_targets)

    loss.backward()
    optimizer.step()

Both pre_softmax and float_targets have dtype torch.float32. In the original code the targets were converted to torch.int64 using torch.tensor(targets, dtype=torch.long, device=device). However, that version raised

RuntimeError: expected scalar type Float but found Double

so I changed the dtype of the targets to torch.float32.

Even though both arguments passed to cn_loss() are torch.float32, I still get the error below when I run the code:

    loss = cn_loss(pre_softmax, float_targets)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\loss.py", line 1152, in forward
    label_smoothing=self.label_smoothing)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: expected scalar type Float but found Double

I have checked the data types several times, and I also tried building float_targets with torch.FloatTensor(targets), but I get the same error.

Upvotes: 0

Views: 1575

Answers (1)

nr spider

Reputation: 174

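As pointed out by @aretor in the comments, the dtype of train_label_weight was torch.float64. The weight tensor passed to CrossEntropyLoss has to match the dtype of the logits, which is why the error mentioned Double even though both arguments to cn_loss() were torch.float32. After I converted train_label_weight to torch.float32 and changed the targets back to torch.long, the code worked.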
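For reference, here is a minimal sketch of the fix. The class count, batch size, weight values, and random logits are made up for illustration and stand in for the real model output and label weights. It assumes the weights ended up as float64 the way they often do when computed outside PyTorch; the original post does not say how train_label_weight was built.

```python
import torch

# Hypothetical setup: 3 classes, batch of 4. The random logits stand in for model(sents).
num_classes = 3
pre_softmax = torch.randn(4, num_classes, requires_grad=True)  # float32 logits
targets = [0, 2, 1, 2]  # class indices, one per sample

# Assumed cause: the class weights were created as float64 (illustrative values).
train_label_weight = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64)

# Fix 1: cast the weights to the same dtype as the logits before building the loss.
train_label_weight = train_label_weight.to(pre_softmax.dtype)
cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight, reduction='mean')

# Fix 2: class-index targets should be torch.long (int64), not float.
long_targets = torch.tensor(targets, dtype=torch.long)

loss = cn_loss(pre_softmax, long_targets)
loss.backward()

# Quick dtype check of everything the loss sees, including the weight.
print(pre_softmax.dtype, long_targets.dtype, cn_loss.weight.dtype)
```

Printing the dtype of cn_loss.weight alongside the two call arguments is a quick way to catch this kind of mismatch, since the weight is not visible at the call site.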

Upvotes: 3
