Reputation: 177
So I have a tensor
A = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]])
and I was hoping to rearrange it in sliding windows, that is,
f(A) = torch.tensor([[1,2,5,6],[3,4,7,8],[9,10,13,14],[11,12,15,16]])
by viewing every row of A
as a small 2-by-2 window, and then arranging the windows in a 2-by-2 grid. This operation is similar to ArrayFlatten
in Mathematica, but I couldn't find a way to do it in PyTorch. Any help is welcome.
Upvotes: 1
Views: 281
Reputation: 24201
I can't think of a neat way to do this off the top of my head, but you can achieve it with a few choice slices and concatenations:
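A_ = A.reshape(4,2,2)  # each row of A becomes one 2-by-2 window
torch.cat([torch.cat([*A_[:2]],1), torch.cat([*A_[2:]],1)],0)  # join window pairs side by side, then stack the two halves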
Alternative approaches:
A.unfold(1,2,2).unfold(0,2,2).reshape(*A.shape).index_select(1, torch.tensor([0,2,1,3]))
torch.cat([*A_],1).reshape(*A.shape).index_select(0, torch.tensor([0,2,1,3]))  # reuses A_ from above
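For what it's worth, the same rearrangement can probably be written with just a reshape and a permute, which avoids the concatenations and the index fix-up. This is only a sketch: the helper name `tile_windows` and its `grid` / `window` parameters are made up for illustration and are not part of PyTorch. It assumes A has one flattened window per row, with the windows listed in row-major grid order.

```python
import torch

def tile_windows(A, grid, window):
    """Hypothetical helper: treat each row of A as a flattened window of
    shape `window`, then lay the windows out row-major in a `grid` of blocks."""
    gr, gc = grid      # number of windows down / across
    wr, wc = window    # rows / columns inside one window
    assert A.shape == (gr * gc, wr * wc), "A must hold one flattened window per row"
    # Axes after reshape: (grid row, grid col, window row, window col).
    # Swapping the two middle axes puts the window row next to the grid row,
    # so the final reshape merges them into the rows of the output.
    return A.reshape(gr, gc, wr, wc).permute(0, 2, 1, 3).reshape(gr * wr, gc * wc)

A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
expected = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8], [9, 10, 13, 14], [11, 12, 15, 16]])

result = tile_windows(A, grid=(2, 2), window=(2, 2))
assert torch.equal(result, expected)
```

The permute makes the tensor non-contiguous, so the final reshape will copy the data rather than return a view.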
Upvotes: 1