bindas

Reputation: 53

Converting 4 dimensional tensors into list of lists of lists (Python)

I have 6 tensors of shape (batch_size, S, S, 1) and I want to combine them into one Python list of shape (batch_size, S*S, 6), so that each innermost list holds the corresponding element from each of the 6 tensors.

Can this be achieved without using loops? What's an efficient way to do it?

Upvotes: 0

Views: 155

Answers (1)

Ivan

Reputation: 40628

Let batch_size=10 and S=4 for the purpose of this example:

>>> import torch
>>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]

The first step is to concatenate the tensors along the last dimension (axis=3):

>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])
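As a quick sanity check, here is a minimal sketch (assuming the same example sizes, batch_size=10 and S=4) showing that after concatenation, channel k of the result holds exactly the values of the k-th input tensor. It also shows an equivalent alternative that is not part of the original answer: squeezing the trailing size-1 axis and using torch.stack on a new last axis.

```python
import torch

# Assumed example sizes, matching the answer: batch_size=10, S=4
batch_size, S = 10, 4
x = [torch.rand(batch_size, S, S, 1) for _ in range(6)]

# Concatenate along the last (size-1) axis -> (batch_size, S, S, 6)
y = torch.cat(x, -1)

# Channel k of y holds exactly the values of the k-th input tensor
for k, t in enumerate(x):
    assert torch.equal(y[..., k], t[..., 0])

# Equivalent alternative: drop the trailing axis, then stack on a new last axis
y_stack = torch.stack([t.squeeze(-1) for t in x], dim=-1)
assert torch.equal(y, y_stack)

print(y.shape)  # torch.Size([10, 4, 4, 6])
```

Both approaches allocate one new tensor; torch.cat avoids the extra squeeze views, so it is the more direct choice here.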

Then flatten axis=1 and axis=2 into a single axis. You can use torch.flatten here since the two axes are adjacent:

>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])
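The question asks for a Python list rather than a tensor. A minimal sketch, assuming the same example sizes: calling torch.Tensor.tolist() on the flattened result produces the nested list of shape (batch_size, S*S, 6). The explicit loop below is only a hypothetical reference used to confirm the element ordering; it is not needed in practice.

```python
import torch

# Assumed example sizes, matching the answer: batch_size=10, S=4
batch_size, S = 10, 4
x = [torch.rand(batch_size, S, S, 1) for _ in range(6)]

# Concatenate on the last axis, then merge axes 1 and 2 -> (batch_size, S*S, 6)
y = torch.cat(x, -1).flatten(1, 2)

# flatten(1, 2) is equivalent to a reshape with the merged size spelled out
assert torch.equal(y, torch.cat(x, -1).reshape(batch_size, S * S, 6))

# Convert to nested Python lists: batch_size x (S*S) x 6 floats
result = y.tolist()

# Hypothetical loop-based reference: row i, column j maps to index i*S + j,
# and the innermost list takes one value from each of the 6 tensors
expected = [
    [[x[k][b, i, j, 0].item() for k in range(6)] for i in range(S) for j in range(S)]
    for b in range(batch_size)
]
assert result == expected
```

Only the final tolist() call leaves PyTorch; the concatenation and flattening stay vectorized, which is what avoids the Python-level loops.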

Upvotes: 1
