I am working with video frame data where the input comes as a tensor of the form [batch, channel, frame_sequence, height, width] (let's denote it by [B,C,S,H,W] for clarity). So each batch basically consists of a consecutive sequence of frames. What I want to do is run an encoder (consisting of several conv2d layers) on each frame, i.e. on each [C,H,W], and get the result back as [B,C_output,S,H_output,W_output]. Now conv2d expects input in (N,C_in,H_in,W_in) form. I am wondering what's the best way to do this without messing up the order within the 5D tensor. So far I am considering the following approach:
>>> # B, C, seq, h, w = 4, 2, 5, 3, 3
>>> x = Variable(torch.rand(4, 2, 5, 3, 3))
>>> x.size()
torch.Size([4, 2, 5, 3, 3])
>>> x = x.permute(0, 2, 1, 3, 4)
>>> x.size()  # expected: [4, 5, 2, 3, 3], i.e. B, seq, C, h, w
torch.Size([4, 5, 2, 3, 3])
>>> x = x.contiguous().view(-1, 2, 3, 3)
>>> x.size()
torch.Size([20, 2, 3, 3])
And then run conv2d (the encoder) on the updated x and reshape the result back. But I think this wouldn't preserve the original order of the tensor. So, how can I achieve the goal?
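To make it concrete, here is the full round-trip I have in mind, sketched with a hypothetical toy conv2d standing in for my real encoder (just to track the shapes; the reshape back at the end is exactly the part I am unsure about):

import torch
import torch.nn as nn

# hypothetical toy encoder, only here to illustrate the shapes
encoder = nn.Conv2d(in_channels=2, out_channels=8, kernel_size=3, padding=1)

x = torch.rand(4, 2, 5, 3, 3)               # [B, C, S, H, W]
B, C, S, H, W = x.size()

x = x.permute(0, 2, 1, 3, 4).contiguous()   # [B, S, C, H, W]
x = x.view(B * S, C, H, W)                  # merge batch and sequence

y = encoder(x)                              # [B*S, C_out, H_out, W_out]
C_out, H_out, W_out = y.size(1), y.size(2), y.size(3)

y = y.view(B, S, C_out, H_out, W_out)       # split batch and sequence again
y = y.permute(0, 2, 1, 3, 4)                # [B, C_out, S, H_out, W_out]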
Upvotes: 3
Views: 2468
What you are doing is completely fine; it will preserve the order. You can verify this by visualizing the frames.
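If you prefer a numerical check over eyeballing images, a quick sketch like the following (not from the original post, just an ordering check on random data) shows that each frame lands exactly where you expect after the permute + view:

import torch

x = torch.rand(4, 2, 5, 3, 3)                        # [B, C, S, H, W]
S = x.size(2)
flat = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, 2, 3, 3)

# frame s of batch element b should sit at row b * S + s after flattening
b, s = 2, 3
print(torch.equal(flat[b * S + s], x[b, :, s]))      # True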
I quickly built this for displaying the images stored in a 4d tensor (where dim=0 is batch) or a 5d tensor (where dim=0 is batch and dim=1 is sequence):
import matplotlib.pyplot as plt
import numpy as np

def custom_imshow(tensor):
    if tensor.dim() == 4:  # [batch, C, H, W]
        count = 1
        for i in range(tensor.size(0)):
            img = tensor[i].numpy()
            plt.subplot(1, tensor.size(0), count)
            img = img / 2 + 0.5                  # unnormalize
            img = np.transpose(img, (1, 2, 0))   # CHW -> HWC for imshow
            plt.imshow(img)
            plt.axis('off')
            count += 1
    if tensor.dim() == 5:  # [batch, seq, C, H, W]
        count = 1
        for i in range(tensor.size(0)):
            for j in range(tensor.size(1)):
                img = tensor[i][j].numpy()
                plt.subplot(tensor.size(0), tensor.size(1), count)
                img = img / 2 + 0.5                  # unnormalize
                img = np.transpose(img, (1, 2, 0))   # CHW -> HWC for imshow
                plt.imshow(img)
                plt.axis('off')
                count += 1
Let's say we use the CIFAR-10 dataset (consisting of 32x32x3 images).
For a tensor x:
>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)
After doing x.view(-1, 3, 32, 32):
# x.size() -> torch.Size([4, 5, 3, 32, 32])
>>> x = x.view(-1, 3, 32, 32)
>>> x.size()
torch.Size([20, 3, 32, 32])
>>> custom_imshow(x)
And if you go back to a 5d tensor view:
# x.size() -> torch.Size([20, 3, 32, 32])
>>> x = x.view(4, 5, 3, 32, 32)
>>> x.size()
torch.Size([4, 5, 3, 32, 32])
>>> custom_imshow(x)
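To close the loop on the original question, here is a rough sketch (with a made-up single-conv encoder, purely for shape bookkeeping) of running the encoder on the flattened tensor and then restoring the 5d layout the asker wants, i.e. [B, C_out, S, H_out, W_out]:

import torch
import torch.nn as nn

# made-up encoder, not part of the original question
encoder = nn.Conv2d(3, 16, kernel_size=3, padding=1)

x = torch.rand(4, 5, 3, 32, 32)                 # [B, S, C, H, W], as in the example above
B, S = x.size(0), x.size(1)

feats = encoder(x.view(-1, 3, 32, 32))          # [B*S, 16, 32, 32]
feats = feats.view(B, S, *feats.size()[1:])     # [B, S, 16, 32, 32]
feats = feats.permute(0, 2, 1, 3, 4)            # [B, 16, S, 32, 32]
print(feats.size())                             # torch.Size([4, 16, 5, 32, 32])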
Upvotes: 3