Reputation: 43511
I have a PyTorch LSTM model, and my forward function looks like:
def forward(self, x, hidden):
    print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
    lstm_out, hidden = self.lstm(x, hidden)
    return lstm_out, hidden
All of the print statements show torch.float64, which I believe is a double. So why am I getting this issue? I've already cast to double in all of the relevant places.
Upvotes: 3
Views: 4693
Reputation: 25401
Make sure both your data and your model are in dtype double.
For the model:
net = net.double()
For the data:
net(x.double())
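Putting the two together, here is a minimal self-contained sketch (with hypothetical layer sizes) showing that once the LSTM's parameters, the input, and the hidden state are all torch.float64, the forward pass runs without a dtype mismatch:

```python
import torch
import torch.nn as nn

# Hypothetical dimensions for illustration only.
lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
lstm = lstm.double()  # cast all parameters to torch.float64

x = torch.randn(4, 5, 10).double()               # (batch, seq, features) in float64
h0 = torch.zeros(1, 4, 20, dtype=torch.float64)  # initial hidden state
c0 = torch.zeros(1, 4, 20, dtype=torch.float64)  # initial cell state

out, (hn, cn) = lstm(x, (h0, c0))
print(out.dtype)  # torch.float64
```

Note that the hidden-state tuple must be cast as well; an LSTM called with float64 weights and a float32 hidden state raises the same kind of dtype error.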
This has been discussed on the PyTorch forums.
Upvotes: 6