t-smart

Rearranging PyTorch tensor in a windowed manner

So I have a tensor

A = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]])

and I was hoping to rearrange it in non-overlapping windows, that is,

f(A) = torch.tensor([[1,2,5,6],[3,4,7,8],[9,10,13,14],[11,12,15,16]])

by viewing each row of A as a small 2-by-2 window (filled row by row) and then arranging those windows in a 2-by-2 grid. This operation is similar to Mathematica's ArrayFlatten, but I couldn't find a way to do it in PyTorch. Any help is welcome.
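To pin down the mapping described above, here is a minimal loop-based sketch. It assumes each row of A holds one w-by-w window in row-major order and that the windows are placed into a square grid row by row; the function name `window_flatten_loop` is hypothetical, and it is meant as a slow reference to check faster versions against, not as the way to do it.

```python
import math
import torch

def window_flatten_loop(A: torch.Tensor, w: int = 2) -> torch.Tensor:
    """Hypothetical reference version: treat each row of A as a w-by-w
    window (row-major) and tile the windows into a square grid, row by row."""
    n_windows = A.shape[0]
    g = math.isqrt(n_windows)  # windows per grid side; assumes a square grid
    if g * g != n_windows or A.shape[1] != w * w:
        raise ValueError("expected g*g rows of length w*w")
    out = torch.empty(g * w, g * w, dtype=A.dtype, device=A.device)
    for idx in range(n_windows):
        p, q = divmod(idx, g)  # grid (row, column) of window idx
        out[p * w:(p + 1) * w, q * w:(q + 1) * w] = A[idx].reshape(w, w)
    return out

A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
expected = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8], [9, 10, 13, 14], [11, 12, 15, 16]])
assert torch.equal(window_flatten_loop(A), expected)
```

The explicit loop is only practical for small inputs, but it makes the index mapping easy to read: window `idx` lands in grid block `(idx // g, idx % g)`.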

Answers (1)

iacob

I can't think of a neat way to do this off the top of my head, but you can achieve it with a few choice slices and concatenations:

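import torch
A_ = A.reshape(4,2,2)  # row i of A becomes the i-th 2x2 window
torch.cat([torch.cat([*A_[:2]],1), torch.cat([*A_[2:]],1)],0)  # put windows side by side, then stack the two halves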

Alternative approaches:

A.unfold(1,2,2).unfold(0,2,2).reshape(*A.shape).index_select(1, torch.tensor([0,2,1,3]))
torch.cat([*A_],1).reshape(*A.shape).index_select(0, torch.tensor([0,2,1,3]))  # reuses A_ from above
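A further variant, using a reshape/permute chain instead of the calls above, can be sketched as follows. It assumes A has g*g rows of length w*w (g = w = 2 in the question), and `window_flatten` is a hypothetical helper name. Reshaping to (g, g, w, w) indexes elements as [grid row, grid column, window row, window column]; swapping the two middle axes reorders them to [grid row, window row, grid column, window column], which is exactly the row/column order of the tiled output, so a final reshape produces it.

```python
import math
import torch

def window_flatten(A: torch.Tensor, w: int) -> torch.Tensor:
    """Hypothetical helper: view each row of A as a w-by-w window and
    tile the windows into a square grid, row by row."""
    g = math.isqrt(A.shape[0])  # number of windows per grid side
    if g * g != A.shape[0] or A.shape[1] != w * w:
        raise ValueError("expected g*g rows of length w*w")
    # [grid_row, grid_col, win_row, win_col] -> [grid_row, win_row, grid_col, win_col]
    return A.reshape(g, g, w, w).permute(0, 2, 1, 3).reshape(g * w, g * w)

A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
expected = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8], [9, 10, 13, 14], [11, 12, 15, 16]])
assert torch.equal(window_flatten(A, 2), expected)
```

Since `permute` only reorders axes, the final `reshape` has to copy the data. For the 4-by-4 example this gives the same tensor as the expressions above, and the same code covers other window sizes, e.g. four 3-by-3 windows tiled into a 6-by-6 matrix.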
