melatonin15

Reputation: 2289

Running conv2d on a tensor [batch, channel, sequence, H, W] in PyTorch

I am working with video frame data where the input is a tensor of the form [batch, channel, frame_sequence, height, width] (denoted [B, C, S, H, W] for clarity), so each batch element is a sequence of consecutive frames. I want to run an encoder (consisting of several conv2d layers) on each frame, i.e. on each [C, H, W] slice, and get the result back as [B, C_output, S, H_output, W_output]. However, conv2d expects input of the form (N, C_in, H_in, W_in). What is the best way to do this without messing up the order within the 5D tensor? So far I have considered the following approach:

>>> import torch
>>> # B, C, seq, h, w
>>> # 4, 2, 5,   3, 3
>>> x = torch.rand(4, 2, 5, 3, 3)
>>> x.size()
torch.Size([4, 2, 5, 3, 3])
>>> x = x.permute(0, 2, 1, 3, 4)
>>> x.size()  # expected: 4, 5, 2, 3, 3 (B, seq, C, h, w)
torch.Size([4, 5, 2, 3, 3])
>>> x = x.contiguous().view(-1, 2, 3, 3)
>>> x.size()
torch.Size([20, 2, 3, 3])

I would then run conv2d (the encoder) on the updated x and reshape the result. But I suspect this won't preserve the original order of the tensor. How can I achieve this?
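A minimal sketch of that full round trip, assuming a hypothetical two-layer `encoder` and a hypothetical helper name `apply_per_frame` (neither comes from the original post): fold S into the batch dimension, run the 2D encoder, then unfold and permute back to [B, C_out, S, H_out, W_out].

```python
import torch
import torch.nn as nn


def apply_per_frame(encoder: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Apply a 2D encoder to every frame of a [B, C, S, H, W] tensor.

    Returns a tensor of shape [B, C_out, S, H_out, W_out].
    """
    b, c, s, h, w = x.shape
    # [B, C, S, H, W] -> [B, S, C, H, W] -> [B*S, C, H, W]
    frames = x.permute(0, 2, 1, 3, 4).reshape(b * s, c, h, w)
    out = encoder(frames)  # [B*S, C_out, H_out, W_out]
    c_out, h_out, w_out = out.shape[1:]
    # [B*S, C_out, H_out, W_out] -> [B, S, C_out, H_out, W_out]
    # -> [B, C_out, S, H_out, W_out]
    return out.reshape(b, s, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)


# Hypothetical encoder, for illustration only.
encoder = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)

x = torch.rand(4, 2, 5, 3, 3)  # B, C, S, H, W as in the question
y = apply_per_frame(encoder, x)
print(y.shape)  # torch.Size([4, 16, 5, 2, 2])
```

The returned tensor is a permuted view and therefore not contiguous; call `.contiguous()` on it if a later step needs `view`.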


Answers (1)

entrophy

Reputation: 2125

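What you are doing is completely fine, and it preserves the order. After the permute the tensor is laid out as [B, S, C, H, W], and view(-1, C, H, W) merges the first two dimensions in row-major order, so frame s of batch element b ends up at index b * S + s. Reshaping the encoder output back with view(B, S, C_out, H_out, W_out) therefore puts every frame back where it came from. You can verify this by visualizing the frames.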
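The same claim can also be checked numerically, without plotting. A minimal sketch using the question's toy sizes: after `permute(0, 2, 1, 3, 4)` and flattening, frame `s` of batch element `b` should sit at row `b * S + s`, because B and S are merged in row-major order.

```python
import torch

B, C, S, H, W = 4, 2, 5, 3, 3
x = torch.rand(B, C, S, H, W)

# Fold the sequence dimension into the batch dimension.
flat = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, C, H, W)  # [B*S, C, H, W]

# Every frame x[b, :, s] should sit at row b * S + s of the flattened tensor.
order_ok = all(
    torch.equal(flat[b * S + s], x[b, :, s])
    for b in range(B)
    for s in range(S)
)
print(order_ok)  # True

# Unflattening and permuting back recovers the original tensor exactly.
restored = flat.view(B, S, C, H, W).permute(0, 2, 1, 3, 4)
print(torch.equal(restored, x))  # True
```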

I quickly wrote this helper to display the images stored in a 4D tensor (where dim 0 is the batch) or a 5D tensor (where dim 0 is the batch and dim 1 is the sequence):

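import matplotlib.pyplot as plt
import numpy as np


def custom_imshow(tensor):
    """Plot the images in a 4D [N, C, H, W] or 5D [B, S, C, H, W] tensor."""
    tensor = tensor.detach().cpu()  # .numpy() needs a CPU tensor without grad
    if tensor.dim() == 4:
        # One row containing all N images
        count = 1
        for i in range(tensor.size(0)):
            img = tensor[i].numpy()
            plt.subplot(1, tensor.size(0), count)
            img = img / 2 + 0.5     # unnormalize (inverts mean=0.5, std=0.5)
            img = np.transpose(img, (1, 2, 0))  # C, H, W -> H, W, C
            count += 1
            plt.imshow(img)
            plt.axis('off')

    elif tensor.dim() == 5:
        # One row per batch element, one column per frame
        count = 1
        for i in range(tensor.size(0)):
            for j in range(tensor.size(1)):
                img = tensor[i][j].numpy()
                plt.subplot(tensor.size(0), tensor.size(1), count)
                img = img / 2 + 0.5  # unnormalize
                img = np.transpose(img, (1, 2, 0))  # C, H, W -> H, W, C

                plt.imshow(img)
                plt.axis('off')
                count += 1

    plt.show()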

Let's say we use the CIFAR-10 dataset (32x32 RGB images, i.e. 3x32x32 tensors).

For a tensor x:

>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)

[Image: the 20 frames plotted as a 4 × 5 grid, one row per batch element]

After doing x.view(-1, 3, 32, 32):

# x.size() -> torch.Size([4, 5, 3, 32, 32])
>>> x = x.view(-1, 3, 32, 32)
>>> x.size()
torch.Size([20, 3, 32, 32])
>>> custom_imshow(x)

[Image: the same 20 frames plotted in a single row]

And if you go back to a 5d tensor view:

# x.size() -> torch.Size([20, 3, 32, 32])
>>> x = x.view(4, 5, 3, 32, 32)
>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)

[Image: the 4 × 5 grid again, matching the first plot]
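One caveat when starting from the question's [B, C, S, H, W] layout rather than the [B, S, C, H, W] layout used in the CIFAR example: `permute` only changes strides, so the result is usually not contiguous and `view` cannot merge B and S. That is why the question calls `.contiguous()` first; `reshape` performs the copy automatically when one is needed. A small check, as a sketch:

```python
import torch

x = torch.rand(4, 2, 5, 3, 3)      # B, C, S, H, W
p = x.permute(0, 2, 1, 3, 4)        # B, S, C, H, W (a strided view, no copy)
print(p.is_contiguous())            # False

# view cannot merge B and S here without copying the data.
view_failed = False
try:
    p.view(-1, 2, 3, 3)
except RuntimeError:
    view_failed = True
print(view_failed)                  # True

a = p.contiguous().view(-1, 2, 3, 3)  # explicit copy, then view
b = p.reshape(-1, 2, 3, 3)            # copies only if it has to
print(torch.equal(a, b))              # True
```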

