Reputation: 683
This question is very similar to filtering np.nan values from PyTorch in a 1-dimensional tensor. The difference is that I want to apply the same concept to tensors of 2 or higher dimensions.
I have a tensor that looks like this:
import torch
tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
[float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
[2, 2, 2, 2, 2]]
)
>>> tensor.shape
torch.Size([3, 5])
I would like to find the most Pythonic / PyTorch way to filter out (remove) the rows of the tensor which are nan. By filtering this tensor along the first (0th) axis, I want to obtain a filtered_tensor which looks like this:
>>> print(filtered_tensor)
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]])
>>> filtered_tensor.shape
torch.Size([2, 5])
Upvotes: 7
Views: 17008
Reputation: 2493
Use PyTorch's isnan() together with any() to slice tensor's rows using the obtained boolean mask as follows:
filtered_tensor = tensor[~torch.any(tensor.isnan(),dim=1)]
Note that this will drop any row that has a nan value in it. If you want to drop only rows where all values are nan, replace torch.any with torch.all.
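To make the difference between the two masks concrete, here is a small sketch. The tensor named example and its values are made up for illustration and are not part of the question; it contains one row that is partially nan and one that is entirely nan.

```python
import torch

nan = float("nan")

# Hypothetical input: row 1 is partially nan, row 2 is entirely nan.
example = torch.tensor(
    [[1.0, 1.0, 1.0],
     [nan, 2.0, 2.0],
     [nan, nan, nan],
     [3.0, 3.0, 3.0]]
)

is_nan = torch.isnan(example)

# Drop a row if ANY of its values is nan: rows 1 and 2 are removed.
dropped_any = example[~is_nan.any(dim=1)]

# Drop a row only if ALL of its values are nan: only row 2 is removed.
dropped_all = example[~is_nan.all(dim=1)]

print(dropped_any.shape)
print(dropped_all.shape)
```

With the any() mask two rows survive; with the all() mask three rows survive, and the partially nan row is one of them.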
For an N-dimensional tensor you could just flatten all the dims apart from the first dim and apply the same procedure as above:
# Flatten every dim except the first:
shape = tensor.shape
tensor_reshaped = tensor.reshape(shape[0], -1)
# Drop all rows containing any nan:
tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(), dim=1)]
# Reshape back to the original trailing dims:
tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0], *shape[1:])
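The flatten-and-mask steps above could be wrapped in a small helper. This is only a sketch: the function name drop_nan_slices and its drop_if parameter are hypothetical, not part of PyTorch. It flattens only the nan mask rather than the data, so the original tensor can be indexed directly along dim 0 and no reshape back is needed.

```python
import torch


def drop_nan_slices(t: torch.Tensor, drop_if: str = "any") -> torch.Tensor:
    """Remove slices along dim 0 that contain nan (hypothetical helper).

    drop_if="any": drop a slice if any of its values is nan.
    drop_if="all": drop a slice only if all of its values are nan.
    Assumes t has at least one element along dim 0.
    """
    # Boolean nan mask, flattened to (t.shape[0], everything else).
    flat_nan = torch.isnan(t).reshape(t.shape[0], -1)
    if drop_if == "any":
        bad = flat_nan.any(dim=1)
    elif drop_if == "all":
        bad = flat_nan.all(dim=1)
    else:
        raise ValueError("drop_if must be 'any' or 'all'")
    # A 1-D boolean mask indexes dim 0 and keeps the trailing dims intact.
    return t[~bad]


# Usage on a made-up 3-D batch with a single nan in slice 1.
batch = torch.zeros(3, 2, 4)
batch[1, 0, 2] = float("nan")

kept = drop_nan_slices(batch)
print(kept.shape)
```

Here slice 1 is dropped under the default "any" rule, while drop_if="all" would keep all three slices because no slice is entirely nan.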
Upvotes: 11