Reputation: 718
In order to replace positive values with one number and negative values with another in a random array using NumPy, one can do the following:
import numpy as np
c_plus, c_minus = 1.0, -1.0  # example replacement values
npy_p = np.random.randn(4, 6)
quant = np.where(npy_p > 0, c_plus, np.where(npy_p < 0, c_minus, npy_p))
However, the where method in PyTorch (torch.where) throws the following error:
expected scalar type double but found float
Can you help me with that?
Upvotes: 0
Views: 4296
Reputation: 8318
I can't reproduce this error. It would help if you could share a specific example where it fails (the cause might be the values you try to fill the tensor with). This works for me:
import torch
x = torch.rand(4, 6)
res = torch.where(x > 0.3, torch.tensor(0.), torch.where(x < 0.1, torch.tensor(-1.), x))
Here x, which has dtype float32, is:
tensor([[0.1391, 0.4491, 0.2363, 0.3215, 0.7740, 0.4879],
[0.3051, 0.0870, 0.2869, 0.2575, 0.8825, 0.8201],
[0.4419, 0.1138, 0.0825, 0.9489, 0.1553, 0.6505],
[0.8376, 0.7639, 0.9291, 0.0865, 0.5984, 0.3953]])
And res is:
tensor([[ 0.1391, 0.0000, 0.2363, 0.0000, 0.0000, 0.0000],
[ 0.0000, -1.0000, 0.2869, 0.2575, 0.0000, 0.0000],
[ 0.0000, 0.1138, -1.0000, 0.0000, 0.1553, 0.0000],
[ 0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000]])
The problem is caused by mixing data types in torch.where: as the error message says, one argument is double (float64) while another is float (float32). If you explicitly give your constants the same dtype as the tensor, it works fine.
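A minimal sketch of that fix applied to the NumPy snippet from the question, assuming c_plus and c_minus are plain Python floats (1.0 and -1.0 here are only example values). The helper replace_by_sign is hypothetical; it builds both constants with the input's dtype and device, so it works whether the input is float32 (the default for torch.randn) or float64 (for example, an array converted with torch.from_numpy, which keeps NumPy's float64):

```python
import numpy as np
import torch


def replace_by_sign(x: torch.Tensor, c_plus: float, c_minus: float) -> torch.Tensor:
    """Hypothetical helper: positive entries -> c_plus, negative -> c_minus, zeros kept.

    Both constants are created with x's dtype and device, so torch.where
    never receives a float32/float64 mix.
    """
    plus = torch.tensor(c_plus, dtype=x.dtype, device=x.device)
    minus = torch.tensor(c_minus, dtype=x.dtype, device=x.device)
    return torch.where(x > 0, plus, torch.where(x < 0, minus, x))


# float32 input: PyTorch's default floating dtype
x32 = torch.randn(4, 6)
# float64 input: torch.from_numpy keeps NumPy's float64 dtype
x64 = torch.from_numpy(np.random.randn(4, 6))

res32 = replace_by_sign(x32, 1.0, -1.0)
res64 = replace_by_sign(x64, 1.0, -1.0)
print(res32.dtype, res64.dtype)
```

Printing the two dtypes should show torch.float32 torch.float64: the output follows the input's dtype, and no mixed-type arguments ever reach torch.where.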
Upvotes: 3