Saeed

PyTorch: why does the torch.where method not work like numpy.where?

In NumPy, one can replace the positive values of a random array with one number and the negative values with another as follows:

import numpy as np
c_plus, c_minus = 1.0, -1.0  # example replacement values (hypothetical)
npy_p = np.random.randn(4, 6)
quant = np.where(npy_p > 0, c_plus, np.where(npy_p < 0, c_minus, npy_p))
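To make the nested np.where semantics concrete, here is a small deterministic example. The values of c_plus and c_minus are hypothetical stand-ins, since the question does not give them. Positives become c_plus, negatives become c_minus, and exact zeros fall through both conditions and are kept unchanged.

```python
import numpy as np

c_plus, c_minus = 1.0, -1.0  # hypothetical replacement values

a = np.array([[-2.5, 0.0, 3.0],
              [0.7, -0.1, 0.0]])

# Outer where: positives -> c_plus; everything else goes to the inner where,
# which maps negatives -> c_minus and leaves exact zeros as they are
quant = np.where(a > 0, c_plus, np.where(a < 0, c_minus, a))
print(quant)  # values: [[-1., 0., 1.], [1., -1., 0.]]
```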

However, doing the same with torch.where in PyTorch throws the following error:

expected scalar type double but found float

Can you help me with that?

Answers (1)

David

I can't reproduce this error. It would help if you could share a specific example where it fails (the problem might be the values you are trying to fill the tensor with):

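import torch
x = torch.rand(4, 6)  # float32 by default
# torch.tensor(0.) and torch.tensor(-1.) also use the default dtype (float32), matching x
res = torch.where(x > 0.3, torch.tensor(0.), torch.where(x < 0.1, torch.tensor(-1.), x))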

Here x, which has dtype float32, is:

tensor([[0.1391, 0.4491, 0.2363, 0.3215, 0.7740, 0.4879],
        [0.3051, 0.0870, 0.2869, 0.2575, 0.8825, 0.8201],
        [0.4419, 0.1138, 0.0825, 0.9489, 0.1553, 0.6505],
        [0.8376, 0.7639, 0.9291, 0.0865, 0.5984, 0.3953]])

and res is:

tensor([[ 0.1391,  0.0000,  0.2363,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -1.0000,  0.2869,  0.2575,  0.0000,  0.0000],
        [ 0.0000,  0.1138, -1.0000,  0.0000,  0.1553,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.0000,  0.0000,  0.0000]])

The problem is caused by mixing data types in torch.where: if you explicitly give your constants the same dtype as the tensor, it works fine.
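A minimal sketch of that fix. It assumes the tensor was created from the NumPy array with torch.from_numpy (the question does not say how it was created) and uses hypothetical values for c_plus and c_minus. np.random.randn returns float64, so the resulting tensor is double; building the constants with dtype=p.dtype keeps every argument of torch.where in the same type, whatever that type is.

```python
import numpy as np
import torch

# Hypothetical replacement values (stand-ins for the question's c_plus / c_minus)
c_plus, c_minus = 1.0, -1.0

# np.random.randn returns float64, so torch.from_numpy gives a float64 (double) tensor
p = torch.from_numpy(np.random.randn(4, 6))

# Create the constants with the same dtype and device as p so no types are mixed
plus = torch.tensor(c_plus, dtype=p.dtype, device=p.device)
minus = torch.tensor(c_minus, dtype=p.dtype, device=p.device)

quant = torch.where(p > 0, plus, torch.where(p < 0, minus, p))
```

Matching the constants to p.dtype, rather than hard-coding float32 or float64, means the same code works whether p is a float32 tensor from torch.rand or a double tensor converted from NumPy.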
