Reputation: 1249
Given any general float torch.Tensor, possibly containing some NaN values, I am looking for an efficient method to either replace all the NaN values in it with zero, or remove them altogether and collect the "useful" values in a new Tensor.
I am aware that a trivial way to do this is to manually iterate through all the values in the given tensor (and correspondingly replace them with zero or reject them for the new tensor).
Is there a pre-defined Torch function, or a combination of functions, that can achieve this more efficiently by relying on Torch's built-in CPU/GPU optimisations?
Upvotes: 2
Views: 4485
Reputation: 2160
Well, it looks like there is no function in torch for checking a tensor for NaNs. But since NaN != NaN, there's a workaround: comparing a tensor with itself element-wise flags exactly the NaN entries.
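a = torch.rand(4, 5)
a[2][3] = tonumber('nan')  -- inject a NaN (0/0 also yields NaN)
nan_mask = a:ne(a)         -- 1 where the element is NaN
notnan_mask = a:eq(a)      -- 1 where the element is a regular number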
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 nan 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]
print(nan_mask)
0 0 0 0 0
0 0 1 0 0
0 0 0 0 0
0 0 0 0 0
[torch.ByteTensor of size 4x5]
Having these masks, you can efficiently extract NaN/not NaN values and replace them with whatever you want:
print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]
a[nan_mask] = 42
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 42.0000 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]
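To tie this back to the two things the question asks for (zeroing the NaNs, or keeping only the valid values), here is a minimal sketch. It assumes Torch7's maskedSelect and maskedFill tensor methods and a CPU tensor, so the mask is a ByteTensor. The tensor contents and the names `valid` and `nan_mask` are illustrative only.

```lua
require 'torch'

-- Build a small tensor and inject NaNs at known positions.
local a = torch.Tensor({{1, 2, 3}, {4, 5, 6}})
a[1][2] = 0/0
a[2][3] = 0/0

-- NaN is the only value that is not equal to itself.
local nan_mask = a:ne(a)

-- Option 1: copy the non-NaN values into a new 1-D tensor.
local valid = a:maskedSelect(a:eq(a))

-- Option 2: overwrite the NaNs with zero, in place.
a:maskedFill(nan_mask, 0)

print(valid)
print(a)
```

Both options run as vectorised tensor operations, so no Lua-level loop over the elements is needed. Note that maskedSelect flattens the result, so the original shape is lost when NaNs are filtered out rather than replaced.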
Upvotes: 4