puddles

Reputation: 103

vectorize pytorch tensor indexing

I have a batch of images img_batch, size [8,3,32,32], and I want to manipulate each image by setting randomly selected pixels to zero. I can do this using a for loop over each image, but I'm not sure how to vectorize it so that I'm not processing only one image at a time. This is my code using loops:

import random
import torch

batch_size = 8
prct0 = 0.1
noise = torch.tensor([9, 14, 5, 7, 6, 14, 1, 3])
comb_img = []

for ind in range(batch_size):
    img = img_batch[ind]
    c, h, w = img.shape
    # per-image fraction of pixels to zero out
    prct = 1 - (1 - prct0) ** noise[ind].item()
    # pick that many distinct pixel positions at random
    idx = random.sample(range(h * w), int(prct * h * w))
    img_noised = img.clone()
    img_noised.view(c, 1, -1)[:, 0, idx] = 0
    comb_img.append(img_noised)

comb_img = torch.stack(comb_img)  # output is comb_img [8,3,32,32]

I'm new to PyTorch, so if you see any other improvements, please share.

Upvotes: 0

Views: 2159

Answers (2)

DerekG

Reputation: 3938

You can easily do this without a loop in a fully vectorized manner:

  1. Create a noise tensor
  2. Select a threshold (prct0) and round the noise tensor to 0 or 1 depending on whether each value is above or below that threshold
  3. Element-wise multiply the image tensor by the noise tensor

I think calling the vector of power multipliers noise is a bit confusing, so I've renamed that vector power_vec in this example: power_vec = noise

# create random noise - note one channel rather than 3 color channels
rand_noise = torch.rand(8, 1, 32, 32)
# reshape power_vec to [8,1,1,1] so it broadcasts against rand_noise
noise = torch.pow(rand_noise, power_vec.view(8, 1, 1, 1))

# "round" noise to 0 or 1 based on the threshold
z = torch.zeros(noise.shape)
o = torch.ones(noise.shape)
noise_rounded = torch.where(noise > prct0, o, z)

# apply the binary noise mask to each color channel
output = img_batch * noise_rounded.expand(8, 3, 32, 32)

For simplicity this solution hard-codes your original batch size and image size, but it could be trivially extended to work on any batch and image size; a sketch of that generalization follows.
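As a minimal sketch of that generalization, assuming power_vec and prct0 are defined as above and reading the shapes from img_batch itself, it could look like this:

import torch

# assumed inputs, matching the shapes in the question
img_batch = torch.rand(8, 3, 32, 32)
power_vec = torch.tensor([9, 14, 5, 7, 6, 14, 1, 3])
prct0 = 0.1

b, c, h, w = img_batch.shape

# single-channel random noise, one exponent per image via broadcasting
rand_noise = torch.rand(b, 1, h, w)
noise = torch.pow(rand_noise, power_vec.view(b, 1, 1, 1))

# binary keep/drop mask (equivalent to the torch.where rounding above)
noise_rounded = (noise > prct0).float()

# zero the selected pixels in every color channel
output = img_batch * noise_rounded.expand(b, c, h, w)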

Upvotes: 0

jhso

Reputation: 3283

First note: do you actually need noise? It will be a lot easier if you treat all images the same, rather than setting a different number of pixels to 0 in each one; a minimal sketch of that simpler variant follows.
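For instance, a minimal sketch of that simpler variant, assuming the img_batch and prct0 from the question and a single drop probability for the whole batch:

import torch

img_batch = torch.rand(8, 3, 32, 32)  # stand-in for the real batch
prct0 = 0.1

b, c, h, w = img_batch.shape

# one random value per pixel, shared across the RGB channels
drop = torch.rand(b, 1, h, w) <= prct0

# multiplying by the inverted mask zeroes ~prct0 of the pixels in every image
img_batch_masked = img_batch * (~drop)

Every image then has the same expected fraction of zeroed pixels, which removes the need for any per-image loop.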

However, you can do it with per-image noise as well, although you still need a small for loop (inside the list comprehension):

# don't want RGB masking, want to drop the whole pixel
rng = torch.rand(*img_batch[:, 0:1].shape)
# create the binary mask, one threshold per image
mask = torch.stack([rng[i] <= 1 - (1 - prct0) ** noise[i] for i in range(batch_size)])
img_batch_masked = img_batch.clone()
# broadcast the mask to the 3 RGB channels
img_batch_masked[mask.tile([1, 3, 1, 1])] = 0

You can check that the mask is set correctly by summing mask across the last 3 dims, and seeing if it matches your target percentage:

In [5]:     print(mask.sum([1,2,3])/(mask.shape[2] * mask.shape[3]))
tensor([0.6058, 0.7716, 0.4195, 0.5162, 0.4739, 0.7702, 0.1012, 0.2684])

In [6]:     print(1-(1-prct0)**noise)
tensor([0.6126, 0.7712, 0.4095, 0.5217, 0.4686, 0.7712, 0.1000, 0.2710])

Upvotes: 1
