yusuf

Reputation: 3781

Finding indices of all-zero images in a 4D tensor in PyTorch

I have an interesting question for anyone working with PyTorch.

I have a batch of images, each with shape (3, 224, 224). So if my batch size is, say, 64, the final tensor has shape (64, 3, 224, 224).

Now, here is the question. Suppose that some of the images in this batch are filled with only zeros. What is the fastest way to find out which batch indices contain only zeros?

I don’t want to use a for loop for that, since it is slow.

Thanks for your answer.

Upvotes: 0

Views: 336

Answers (1)

Qianyi Zhang

Reputation: 184

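A cheaper way to do this is to assume that only an empty image has a sum of 0, which I think is pretty reasonable for non-negative pixel values. (If the data can be negative, e.g. after mean/std normalization, a non-empty image could in principle also sum to 0.)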

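import torch
t = torch.rand(64, 3, 224, 224)
t[10] = 0  # make image 10 all zeros
# one sum per image: flatten everything after the batch dimension, shape (64,)
s = t.view(t.size(0), -1).sum(dim=-1)
# .item() fails if more than one image is empty, so return all matching indices
zero_index = (s == 0).nonzero(as_tuple=True)[0]  # tensor([10])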
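If the pixel values can be negative, the sum test above can report a false positive. An exact check compares every element against zero and reduces with all(), which is still a single vectorized pass with no Python loop. Below is a minimal sketch, assuming a tensor of shape (N, ...). The helper name zero_image_indices is hypothetical and the test batch is synthetic:

```python
import torch

def zero_image_indices(batch: torch.Tensor) -> torch.Tensor:
    """Return the batch indices whose images contain only zeros.

    Works for any tensor of shape (N, ...): every dimension after the
    first is flattened and checked element-wise.
    """
    # (batch == 0) is a bool tensor; flatten(1) gives shape (N, C*H*W)
    all_zero = (batch == 0).flatten(1).all(dim=1)  # shape (N,), bool
    return all_zero.nonzero(as_tuple=True)[0]

torch.manual_seed(0)
# A normalized-looking batch: values can be negative.
t = torch.randn(64, 3, 224, 224)
t[10] = 0                # genuinely empty image
t[20] = 0
t[20, 0, 0, 0] = 1.0     # non-empty image whose values sum to exactly zero
t[20, 0, 0, 1] = -1.0

print(zero_image_indices(t))  # tensor([10])

# The sum-based shortcut also flags image 20:
s = t.flatten(1).sum(dim=1)
print((s == 0).nonzero(as_tuple=True)[0])  # tensor([10, 20])
```

Image 20 is not empty, but its values cancel out, so only the element-wise comparison tells the two cases apart.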

Upvotes: 1
