Dhruv Vashist

Reputation: 151

Use of torch.stack()

t1 = torch.tensor([1,2,3])
t2 = torch.tensor([4,5,6])
t3 = torch.tensor([7,8,9])

torch.stack((t1,t2,t3),dim=1)

When using torch.stack(), I can't understand how the stacking is done for different values of dim. Here the tensors end up stacked as columns, but I can't follow the details of how that happens. It becomes even more confusing with 2D or 3D tensors.

tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])

Upvotes: 10

Views: 9246

Answers (3)

Guangfu WANG

Reputation: 1

Others have posted excellent answers to this question. I was confused by this for a long time and figured out one simple (well, maybe not simple for everyone) way to imagine what happens when we stack multi-dimensional tensors into a new one.

The trick is:

  • Imagine that you need to construct the new tensor out of a list of old ones.
  • If you stack along a certain axis, then looking along that axis, each slice of the new tensor is the same as one of the old tensors.

To illustrate this, consider the tensors from the question. When we stack them along dim=1, then viewed column by column the result is [1,2,3] / [4,5,6] / [7,8,9], which is exactly the set of old tensors we stacked.
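
As a quick sanity check of that mental model (a minimal sketch, reusing t1, t2, t3 from the question), indexing along the stacked dimension gives back the original tensors:

import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

s = torch.stack((t1, t2, t3), dim=1)   # shape (3, 3); the tensors become columns

# Looking "along" the stacked axis (indexing dim=1) recovers the old tensors:
print(s[:, 0])   # tensor([1, 2, 3]) -> t1
print(s[:, 1])   # tensor([4, 5, 6]) -> t2
print(s[:, 2])   # tensor([7, 8, 9]) -> t3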

Upvotes: 0

Ivan

Reputation: 40618

Imagine having n tensors. If we stay in 3D, those correspond to volumes, namely rectangular cuboids. Stacking corresponds to combining those n volumes on an additional dimension: here a 4th dimension is added to host the n 3D volumes. This operation is in clear contrast with concatenation, where the volumes would be combined on one of the existing dimensions. So concatenation of three-dimensional tensors would result in a 3D tensor.
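
As a small illustrative sketch (the shapes here are arbitrary assumptions), stacking two 3D tensors adds a 4th dimension, while concatenating them keeps 3 dimensions:

>>> a = torch.empty(2, 3, 4)
>>> b = torch.empty(2, 3, 4)
>>> torch.stack((a, b), dim=0).shape    # a new dimension is added
torch.Size([2, 2, 3, 4])
>>> torch.cat((a, b), dim=0).shape      # an existing dimension grows
torch.Size([4, 3, 4])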

Here is a possible representation of the stacking operation for a limited number of dimensions (up to three-dimensional inputs):

[figure: stacking of 1D, 2D, and 3D inputs, with the new dimension added last]

Where you choose to perform the stacking defines along which new dimension the stack will take place. In the above examples, the newly created dimension is last, hence the idea of an "added dimension" makes more sense.

In the following visualizations, we observe how tensors can be stacked on different axes. This in turn affects the resulting tensor shape.

  • For the 1D case, for instance, it can also happen on the first axis, see below:

[figure: 1D tensors stacked along axis 0 and along axis 1]

    With code:

    >>> x_1d = list(torch.empty(3, 2))     # 3 lines
    
    >>> torch.stack(x_1d, 0).shape         # axis=0 stacking
    torch.Size([3, 2])
    
    >>> torch.stack(x_1d, 1).shape         # axis=1 stacking
    torch.Size([2, 3])
    
  • Similarly for two-dimensional inputs:

[figure: 2D tensors stacked along axes 0, 1, and 2]

    With code:

    >>> x_2d = list(torch.empty(3, 2, 2))   # 3 2x2-squares
    
    >>> torch.stack(x_2d, 0).shape          # axis=0 stacking
    torch.Size([3, 2, 2])
    
    >>> torch.stack(x_2d, 1).shape          # axis=1 stacking
    torch.Size([2, 3, 2])
    
    >>> torch.stack(x_2d, 2).shape          # axis=2 stacking
    torch.Size([2, 2, 3])
    

With this state of mind, you can intuitively extend the operation to n-dimensional tensors.
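
For example (a hedged sketch with an arbitrarily chosen 3D shape), the general pattern is that stacking k tensors of shape S on dim d inserts a new axis of size k at position d:

>>> x_3d = list(torch.empty(3, 2, 4, 5))   # 3 tensors of shape (2, 4, 5)
>>> [torch.stack(x_3d, d).shape for d in range(4)]
[torch.Size([3, 2, 4, 5]), torch.Size([2, 3, 4, 5]), torch.Size([2, 4, 3, 5]), torch.Size([2, 4, 5, 3])]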

Upvotes: 17

Francesco

Reputation: 201

Very simple! I will use 4 variables for this example. The function torch.stack is very similar to NumPy's stacking functions (vstack and hstack). Example:

t1 = torch.tensor([1,2,3])
t2 = torch.tensor([4,5,6])
t3 = torch.tensor([7,8,9])
t4 = torch.tensor([10,11,12])

If you try this command:

>> torch.stack((t1,t2,t3,t4),dim=1).size()
>> torch.Size([3, 4])

and if you replace dim=1 with dim=0:

>> torch.stack((t1,t2,t3,t4),dim=0).size()
>> torch.Size([4, 3])

In the first case you get a 3x4 tensor, while in the second you get a 4x3 tensor. Try this code without .size() to see the values:

>> torch.stack((t1,t2,t3,t4),dim=1)
>> tensor([[ 1,  4,  7, 10],
        [ 2,  5,  8, 11],
        [ 3,  6,  9, 12]])

>> torch.stack((t1,t2,t3,t4),dim=0)
>> tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
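
For the NumPy analogy mentioned above (a rough sketch; np.stack is the closest equivalent, while for these 1D inputs np.vstack matches dim=0 and np.column_stack matches dim=1):

>> import numpy as np
>> a1, a2, a3, a4 = t1.numpy(), t2.numpy(), t3.numpy(), t4.numpy()
>> np.stack((a1, a2, a3, a4), axis=0).shape    # like torch.stack(..., dim=0)
>> (4, 3)
>> np.stack((a1, a2, a3, a4), axis=1).shape    # like torch.stack(..., dim=1)
>> (3, 4)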

Happy Coding!

Upvotes: 3
