How can I split a tensor by column (axis=1)? For example:
"""
input: result:
tensor([[1, 1], (tensor([1, 2, 3, 1, 2, 3]),
[2, 1], tensor([1, 1, 2, 2, 3, 3]))
[3, 2],
[1, 2],
[2, 3],
[3, 3]])
"""
The solution I came up with is to first transpose the input tensor, split it, and then flatten each of the split tensors. However, is there a simpler and more efficient way of doing this? Thank you.
import torch
x = torch.LongTensor([[1,1],[2,1],[3,2],[1,2],[2,3],[3,3]])
x1, x2 = torch.split(x.T, 1)  # transpose to shape (2, 6), then split into two (1, 6) chunks
x1 = torch.flatten(x1)        # (1, 6) -> (6,)
x2 = torch.flatten(x2)        # (1, 6) -> (6,)
x1, x2  # output
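For comparison, the transpose/split/flatten steps can be collapsed into a single call: `torch.unbind` removes a dimension and returns a tuple of all slices along it. A minimal sketch, assuming a 2-D tensor shaped like the one in the question (built here with `torch.tensor` instead of the legacy `torch.LongTensor` constructor):

```python
import torch

x = torch.tensor([[1, 1], [2, 1], [3, 2], [1, 2], [2, 3], [3, 3]])

# unbind along dim=1 yields one 1-D tensor per column,
# so no transpose or flatten step is needed
x1, x2 = torch.unbind(x, dim=1)

print(x1)  # tensor([1, 2, 3, 1, 2, 3])
print(x2)  # tensor([1, 1, 2, 2, 3, 3])
```

Unpacking into `x1, x2` assumes exactly two columns; when the number of columns is not known in advance, keeping the tuple (`cols = x.unbind(1)`) works for any width.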
Simply do:
x1 = x[:, 0]  # all rows, column 0
x2 = x[:, 1]  # all rows, column 1
# x1: tensor([1, 2, 3, 1, 2, 3])
# x2: tensor([1, 1, 2, 2, 3, 3])
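One detail worth knowing about this approach: basic slicing such as `x[:, 0]` returns a view that shares storage with `x`, so an in-place change to `x1` also changes `x`. If independent tensors are needed, `.clone()` the slices. A small sketch illustrating the difference:

```python
import torch

x = torch.tensor([[1, 1], [2, 1], [3, 2], [1, 2], [2, 3], [3, 3]])

x1 = x[:, 0]          # view of column 0 (shares memory with x)
x2 = x[:, 1].clone()  # independent copy of column 1

x1[0] = 100     # in-place write through the view...
print(x[0, 0])  # ...is visible in x: tensor(100)

x2[0] = 200     # in-place write to the clone...
print(x[0, 1])  # ...leaves x unchanged: tensor(1)
```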