dkv

Reputation: 7292

When does PyTorch automatically cast Tensor dtype?

When does PyTorch automatically cast Tensor dtype? Why does it sometimes cast automatically and other times throw an error?

For example this automatically casts c to be a float:

import torch

a = torch.tensor(5)
b = torch.tensor(5.)
c = a*b

a.dtype
>>> torch.int64

b.dtype
>>> torch.float32

c.dtype
>>> torch.float32

But this throws an error:

a = torch.ones(2, dtype=torch.float)   
b = torch.ones(2, dtype=torch.long)    
c = torch.matmul(a,b)

Traceback (most recent call last):

  File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
    torch.matmul(a,b)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'
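One way to sidestep the error (an explicit workaround, not something this PyTorch version does automatically) is to cast one operand yourself before the matmul:

```python
import torch

a = torch.ones(2, dtype=torch.float)
b = torch.ones(2, dtype=torch.long)

# Cast b to a's dtype explicitly so both operands match
c = torch.matmul(a, b.to(a.dtype))
print(c)        # tensor(2.)
print(c.dtype)  # torch.float32
```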

I'm confused since Numpy seems to automatically cast all arrays as necessary e.g.

import numpy as np

a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)

np.matmul(a,b)
>>> 2.0

a*b
>>> array([1., 1.])
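NumPy's promotion rule can likewise be inspected with `np.result_type`. Note this sketch uses `np.int64` and `np.float64` explicitly, since the `np.long` and `np.float` aliases above have been removed in newer NumPy releases:

```python
import numpy as np

a = np.ones(2, dtype=np.int64)
b = np.ones(2, dtype=np.float64)

# NumPy promotes the integer array before the operation
print(np.result_type(a, b))  # float64
print((a * b).dtype)         # float64
print(np.matmul(a, b))       # 2.0
```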

Upvotes: 4

Views: 1968

Answers (1)

Szymon Maszke

Reputation: 24691

It looks like the PyTorch team is working on these types of problems; see this issue. Some basic upcasting already seems to be implemented in 1.0.0, as in your example (probably for the overloaded operators; I tried some others like `//` and addition and they work fine), though I did not find any confirmation of this, such as a GitHub issue or a note in the documentation. If someone finds documentation of implicit casting of torch.Tensor for various operations, please post a comment or another answer.

This issue is a proposal on type promotion; as you can see, all of those issues are still open.
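As a quick check of the operator upcasting described above (a sketch; the exact behavior may vary by PyTorch version):

```python
import torch

a = torch.tensor(5)   # torch.int64
b = torch.tensor(5.)  # torch.float32

# The overloaded operators promote the integer operand to float
print((a + b).dtype)   # torch.float32
print((a * b).dtype)   # torch.float32
print((a // b).dtype)  # also a floating dtype
```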

Upvotes: 3
