Reputation: 1372
I have a list of tuples of PyTorch tensors. It looks like this:
[
(tensor([1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8])),
(tensor([9, 10, 11]), tensor([11, 12, 13, 14]), tensor([15])),
(tensor([16, 17, 18]), tensor([19, 20, 21, 22]), tensor([23])),
...
]
Tensors in each column (that is, the tensors at position k of their respective tuples) share the same shape. I want to stack the tensors in each column so that I end up with a single tuple, where each value is that column's tensors stacked along a new leading dimension.
In this case, the output tuple would have three values, and look like this:
(
tensor([[1,2,3], [9,10,11], [16,17,18]]),
tensor([[4,5,6,7], [11,12,13,14], [19,20,21,22]]),
tensor([[8],[15],[23]])
)
This is a made-up example. I want to do this for tuples of any length, and tensors of arbitrary size. What is the best way to do this type of concatenation quickly using PyTorch?
Upvotes: 2
Views: 7531
Reputation: 1372
If anyone gets themselves into the same convoluted scenario, I was able to solve it with a lovely one-liner:
tuple(map(torch.stack, zip(*x)))
In this case, x is the original list I mentioned above. This line of code transforms x into the exact desired format.
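To make it easier to see what the one-liner is doing, here is a short, self-contained sketch using the made-up values from the question; the shape printout at the end is what you should expect, since every column holds same-shaped tensors:

import torch

x = [
    (torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6, 7]), torch.tensor([8])),
    (torch.tensor([9, 10, 11]), torch.tensor([11, 12, 13, 14]), torch.tensor([15])),
    (torch.tensor([16, 17, 18]), torch.tensor([19, 20, 21, 22]), torch.tensor([23])),
]

# zip(*x) regroups the rows into columns: the first group holds every
# tuple's first tensor, the second group every second tensor, and so on.
# torch.stack then stacks each group along a new leading dimension.
result = tuple(map(torch.stack, zip(*x)))

for t in result:
    print(t.shape)
# torch.Size([3, 3])
# torch.Size([3, 4])
# torch.Size([3, 1])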
Upvotes: 3