Kyle Ong

Reputation: 439

PyTorch: split a tensor by column

How can I split a tensor by column (axis = 1)? For example:

"""
input:              result:
tensor([[1, 1],     (tensor([1, 2, 3, 1, 2, 3]), 
        [2, 1],      tensor([1, 1, 2, 2, 3, 3]))
        [3, 2],  
        [1, 2],
        [2, 3],
        [3, 3]])
"""

The solution I came up with is to first transpose the input tensor, split it, and then flatten each of the split tensors. However, is there a simpler and more efficient way of doing this? Thank you.

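import torch

x = torch.LongTensor([[1,1],[2,1],[3,2],[1,2],[2,3],[3,3]])
# Transpose to shape (2, 6), then split along dim 0 into two (1, 6) chunks
x1, x2 = torch.split(x.T, 1)
# Flatten each (1, 6) chunk into a 1-D tensor of shape (6,)
x1 = torch.flatten(x1)
x2 = torch.flatten(x2)
x1, x2  # output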
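The transpose step above can be skipped by splitting directly along dim=1. Each piece then has shape (6, 1), so its size-1 column dimension has to be squeezed away instead of flattened. A minimal sketch of that variant (the name `col_chunks` is illustrative only):

```python
import torch

x = torch.tensor([[1, 1], [2, 1], [3, 2], [1, 2], [2, 3], [3, 3]])

# Split along dim=1 into chunks of width 1; each chunk has shape (6, 1)
col_chunks = torch.split(x, 1, dim=1)

# Drop the size-1 column dimension so each result is 1-D with shape (6,)
x1, x2 = (chunk.squeeze(1) for chunk in col_chunks)

print(x1)  # tensor([1, 2, 3, 1, 2, 3])
print(x2)  # tensor([1, 1, 2, 2, 3, 3])
```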

Upvotes: 0

Views: 2940

Answers (1)

Alex Metsai

Reputation: 1950

Simply do:

x1 = x[:, 0]
x2 = x[:, 1]
# x1: tensor([1, 2, 3, 1, 2, 3]), x2: tensor([1, 1, 2, 2, 3, 3])
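Indexing each column works well when the number of columns is fixed. If it is not known in advance, `torch.unbind(x, dim=1)` returns a tuple with one 1-D tensor per column, which matches the tuple shown in the question's "result". A short sketch, not part of the original answer, that checks `unbind` agrees with `x[:, i]`:

```python
import torch

x = torch.tensor([[1, 1], [2, 1], [3, 2], [1, 2], [2, 3], [3, 3]])

# unbind removes dim=1 and returns a tuple with one 1-D tensor per column
columns = torch.unbind(x, dim=1)

# Same result as indexing each column with x[:, i]
for i, col in enumerate(columns):
    assert torch.equal(col, x[:, i])

x1, x2 = columns  # unpacking works because x has exactly two columns
print(x1)  # tensor([1, 2, 3, 1, 2, 3])
print(x2)  # tensor([1, 1, 2, 2, 3, 3])
```

For the two-column case, `x1, x2 = x.T` gives the same two tensors as well, because iterating over a tensor yields its slices along dim 0.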

Upvotes: 1
