Reputation: 7292
When does PyTorch automatically cast tensor dtypes? Why does it sometimes cast automatically and other times throw an error?
For example, this automatically casts c to be a float:
a = torch.tensor(5)
b = torch.tensor(5.)
c = a*b
a.dtype
>>> torch.int64
b.dtype
>>> torch.float32
c.dtype
>>> torch.float32
But this throws an error:
a = torch.ones(2, dtype=torch.float)
b = torch.ones(2, dtype=torch.long)
c = torch.matmul(a,b)
Traceback (most recent call last):
File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
torch.matmul(a,b)
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'
I'm confused since NumPy seems to automatically cast arrays as necessary, e.g.
a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)
np.matmul(a,b)
>>> 2.0
a*b
>>> array([1., 1.])
Upvotes: 4
Views: 1968
Reputation: 24691
It looks like the PyTorch team is working on these kinds of problems; see this issue. Some basic upcasting already seems to be implemented in 1.0.0, as your example shows (probably for the overloaded operators; I tried a few others like '//' and addition and they work fine), though I did not find any confirmation of this (such as a GitHub issue or a note in the documentation). If someone finds a reference for implicit casting of torch.Tensor
in various operations, please post a comment or another answer.
This issue is a proposal on type promotion; as you can see, all of those issues are still open.
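In the meantime, the usual workaround is to cast explicitly before calling functions like torch.matmul. A minimal sketch (assuming you want explicit control over the result dtype; note that newer PyTorch versions also expose torch.result_type, which reports what the promotion rules would choose):

```python
import torch

a = torch.ones(2, dtype=torch.float)
b = torch.ones(2, dtype=torch.long)

# Cast b up to float32 explicitly so matmul sees matching dtypes
c = torch.matmul(a, b.float())
print(c)        # tensor(2.)
print(c.dtype)  # torch.float32

# torch.result_type (available in newer releases) shows the dtype
# the promotion rules would pick for this pair of tensors
print(torch.result_type(a, b))  # torch.float32
```

Casting the integer operand up (rather than the float operand down) matches what the elementwise promotion rules do in your a*b example, so no precision is lost.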
Upvotes: 3