jxmorris12

Reputation: 1372

Stacking tensors in a list of tuples of tensors

I have a list of tuples of PyTorch tensors. It looks like this:

[
    (tensor([1, 2, 3]),    tensor([4, 5, 6, 7]),     tensor([8])),
    (tensor([9, 10, 11]),  tensor([11, 12, 13, 14]), tensor([15])),
    (tensor([16, 17, 18]), tensor([19, 20, 21, 22]), tensor([23])),
    ...
]

Tensors in each column (that is, tensors at position k of their respective tuples) share the same shape. I want to stack the tensors in each column so that I end up with a single tuple, where each value is that column's tensors stacked along a new first dimension.

In this case, the output tuple would have three values, and look like this:

(
 tensor([[1, 2, 3], [9, 10, 11], [16, 17, 18]]),

 tensor([[4, 5, 6, 7], [11, 12, 13, 14], [19, 20, 21, 22]]),

 tensor([[8], [15], [23]])
)

This is a made-up example. I want to do this for tuples of any length, and tensors of arbitrary size. What is the best way to do this type of concatenation quickly using PyTorch?
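Note that the desired output adds a new leading dimension instead of joining along an existing one. That is what `torch.stack` does, whereas `torch.cat` on the same 1-D tensors would produce one flat vector. A small sketch using two tensors from the first column of the example:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([9, 10, 11])

# torch.cat joins along an existing dimension: two 1-D tensors of shape (3,)
# become a single 1-D tensor of shape (6,).
catted = torch.cat([a, b])

# torch.stack inserts a new dimension 0: two 1-D tensors of shape (3,)
# become a 2-D tensor of shape (2, 3).
stacked = torch.stack([a, b])

print(catted.shape)   # torch.Size([6])
print(stacked.shape)  # torch.Size([2, 3])
```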

Upvotes: 2

Views: 7531

Answers (1)

jxmorris12

Reputation: 1372

If anyone gets themselves into the same convoluted scenario, I was able to solve it with a lovely one-liner:

tuple(map(torch.stack, zip(*x)))

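In this case, x is the original list from the question. zip(*x) regroups the tuples by position, yielding one group per column, and map(torch.stack, ...) stacks each group along a new first dimension. The result is exactly the desired tuple.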
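A minimal runnable sketch of the one-liner applied to the three example rows from the question (the `...` rows are left out). The intermediate name `columns` is only there to make the two steps visible:

```python
import torch

# The example data from the question, without the elided "..." rows.
x = [
    (torch.tensor([1, 2, 3]),    torch.tensor([4, 5, 6, 7]),     torch.tensor([8])),
    (torch.tensor([9, 10, 11]),  torch.tensor([11, 12, 13, 14]), torch.tensor([15])),
    (torch.tensor([16, 17, 18]), torch.tensor([19, 20, 21, 22]), torch.tensor([23])),
]

# zip(*x) transposes rows into columns: the k-th tuple it yields holds
# the k-th tensor from every row.
columns = list(zip(*x))

# torch.stack takes a sequence of same-shaped tensors and stacks them
# along a new dimension 0, so each column becomes a single tensor.
result = tuple(map(torch.stack, columns))

for t in result:
    print(t.shape)
# torch.Size([3, 3])
# torch.Size([3, 4])
# torch.Size([3, 1])
```

This works for tuples of any length and tensors of any shape, as long as all tensors within one column have the same shape. Otherwise `torch.stack` raises an error.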

Upvotes: 3
