dyno8426

Reputation: 1249

torch7: Filtering out NaN values

Given any general float torch.Tensor, possibly containing some NaN values, I am looking for an efficient method to either replace all the NaN values in it with zero, or remove them altogether and filter out the "useful" values in another new Tensor.

I am aware that a trivial way to do this is to manually iterate through all the values in the given tensor (and correspondingly replace them with zero or reject them for the new tensor).

Is there a predefined Torch function, or a combination of functions, that achieves this more efficiently and takes advantage of Torch's built-in CPU/GPU optimisations?

Upvotes: 2

Views: 4485

Answers (1)

Alexander Lutsenko

Reputation: 2160

Well, it looks like there is no function in torch that checks a tensor for NaNs. But since NaN is the only value that is not equal to itself (NaN != NaN), there's a workaround:

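a = torch.rand(4, 5)
a[2][3] = 0/0            -- NaN; 0/0 is a portable alternative to tonumber('nan')
nan_mask = a:ne(a)       -- 1 where the element is NaN (it differs from itself)
notnan_mask = a:eq(a)    -- 1 where the element is an ordinary number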

print(a)
 0.2434  0.1731  0.3440  0.3340  0.0519
 0.0932  0.4067  nan     0.1827  0.5945
 0.3020  0.1035  0.5415  0.3329  0.7881
 0.6108  0.9498  0.0406  0.9335  0.3582
[torch.DoubleTensor of size 4x5]

print(nan_mask)
 0  0  0  0  0
 0  0  1  0  0
 0  0  0  0  0
 0  0  0  0  0
[torch.ByteTensor of size 4x5]

With these masks, you can efficiently extract the NaN or non-NaN values, or replace them with whatever you want:

print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]

a[nan_mask] = 42
print(a)
  0.2434   0.1731   0.3440   0.3340   0.0519
  0.0932   0.4067  42.0000   0.1827   0.5945
  0.3020   0.1035   0.5415   0.3329   0.7881
  0.6108   0.9498   0.0406   0.9335   0.3582
[torch.DoubleTensor of size 4x5]
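To tie this back to the two things the question asks for (zeroing NaNs, or dropping them into a new tensor), the mask trick can be wrapped in two small helpers. This is only a sketch: the names `replaceNaN` and `dropNaN` are made up for illustration and are not part of torch7, and it assumes a float or double tensor.

```lua
require 'torch'

-- Hypothetical helpers; the names are illustrative, not torch7 API.

-- Return a copy of t with every NaN replaced by `value` (default 0).
local function replaceNaN(t, value)
  local out = t:clone()              -- leave the input untouched
  out[out:ne(out)] = value or 0      -- mask is 1 only where the element is NaN
  return out
end

-- Return a new 1D tensor holding only the non-NaN entries of t.
local function dropNaN(t)
  return t[t:eq(t)]                  -- indexing with a ByteTensor mask selects elements
end

local a = torch.Tensor({{1, 0/0, 3}, {0/0, 5, 6}})
local filled = replaceNaN(a)         -- same shape as a, NaNs set to 0
local kept = dropNaN(a)              -- flattened, NaNs removed
```

Note that `dropNaN` always returns a flattened tensor, since the surviving elements no longer fit the original shape. If modifying the tensor in place is acceptable, the `clone()` can be skipped.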

Upvotes: 4
