Shamoon

Reputation: 43511

RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2

I have a PyTorch LSTM model and my forward function looks like:

    def forward(self, x, hidden):
        print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

All of the print statements show torch.float64, which I believe is double precision. So why am I getting this error?

I've already cast to double in all of the relevant places.

Upvotes: 3

Views: 4693

Answers (1)

David Ferenczy Rogožan

Reputation: 25401

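Make sure both your data and your model are in dtype double. The prints in `forward` only check the inputs; the LSTM's own weights are still float32 (PyTorch's default), and the error is reporting that mismatch between a float64 input and a float32 parameter.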
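A minimal reproduction sketch, assuming a standalone `nn.LSTM` with made-up sizes (`INPUT_SIZE`, `HIDDEN_SIZE`, `SEQ_LEN`, `BATCH` are illustrative, not from the question). It prints the parameter dtypes next to the input dtypes, which is the check the question's `forward` was missing. The exact exception type depends on the PyTorch version (older releases raise `RuntimeError` as in the question, newer ones raise `ValueError` from an input check), so both are caught.

```python
import torch
import torch.nn as nn

# Hypothetical sizes chosen for illustration only.
INPUT_SIZE, HIDDEN_SIZE, SEQ_LEN, BATCH = 4, 8, 5, 2

lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE)  # parameters default to float32

# Input and hidden state cast to float64, as in the question.
x = torch.randn(SEQ_LEN, BATCH, INPUT_SIZE, dtype=torch.float64)
h0 = torch.zeros(1, BATCH, HIDDEN_SIZE, dtype=torch.float64)
c0 = torch.zeros(1, BATCH, HIDDEN_SIZE, dtype=torch.float64)

print('inputs:', x.dtype, h0.dtype, c0.dtype)  # all torch.float64
print('weights:', {p.dtype for p in lstm.parameters()})  # {torch.float32}

try:
    lstm(x, (h0, c0))
    failed = False
except (RuntimeError, ValueError) as err:
    # Older PyTorch raises RuntimeError here; newer versions raise ValueError.
    failed = True
    print(type(err).__name__, err)
```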

For the model:

    net = net.double()

For the data:

    net(x.double())
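Putting both steps together, here is a sketch built around the `forward` from the question. The `Net` class, its constructor arguments, and the tensor sizes are hypothetical; only the `forward` body comes from the question. The hidden state is cast too, since the LSTM expects it to match the input and weight dtype.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical minimal model wrapping the forward() from the question."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden


net = Net(input_size=4, hidden_size=8)
net = net.double()  # converts every floating-point parameter and buffer to float64

x = torch.randn(5, 2, 4)  # (seq_len, batch, input_size), float32 by default
hidden = (torch.zeros(1, 2, 8), torch.zeros(1, 2, 8))

# Cast the data, including the hidden state, to match the model.
out, (h_n, c_n) = net(x.double(), tuple(h.double() for h in hidden))
print(out.dtype, h_n.dtype, c_n.dtype)  # torch.float64 torch.float64 torch.float64
print(out.shape)  # torch.Size([5, 2, 8])
```

The opposite direction works as well: leave the model at its default float32 and cast the data with `.float()` instead. Float32 is usually faster and uses half the memory, so it is worth considering unless you specifically need double precision.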

This has been discussed on the PyTorch forums.

Upvotes: 6
