Reputation: 1379
Given an input tensor of size n x 2A x B x C, how can I split it into two tensors, each of size n x A x B x C? Here, n is the batch size.
Upvotes: 2
Views: 5660
Reputation: 323
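You can use torch.split along dimension 1 with a split size of A. It returns a tuple of tensors, so you can unpack the two halves directly:
tensor_1, tensor_2 = torch.split(input_tensor, split_size_or_sections=A, dim=1)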
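A minimal runnable sketch, assuming arbitrary example sizes (n=4, A=3, B=5, C=6 are made up for illustration). It also shows torch.chunk, which takes the number of pieces instead of the piece size and gives the same result here because 2A divides evenly into two.

```python
import torch

# Arbitrary example sizes; only the second dimension (2*A) matters for the split.
n, A, B, C = 4, 3, 5, 6
input_tensor = torch.randn(n, 2 * A, B, C)

# Split dim 1 into pieces of size A; since that dim has size 2*A, this yields exactly two tensors.
first, second = torch.split(input_tensor, split_size_or_sections=A, dim=1)

# torch.chunk asks for a number of pieces rather than a piece size.
first_c, second_c = torch.chunk(input_tensor, 2, dim=1)

print(first.shape)   # torch.Size([4, 3, 5, 6])
print(second.shape)  # torch.Size([4, 3, 5, 6])
```

Both functions return views of input_tensor rather than copies, so call .clone() on a piece if you need to modify it independently.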
Upvotes: 1
Reputation: 2751
In Lua Torch (Torch7), you could do something like:
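tensor_a = torch.Tensor(n, 2*A, B, C)
-- Initialize tensor_a with the data
-- Lua indices are 1-based and a range {first, last} is inclusive;
-- a single index such as 1 would drop the dimension instead of keeping A slices.
tensor_b = tensor_a[{ {}, {1, A}, {}, {} }]        -- n x A x B x C
tensor_c = tensor_a[{ {}, {A + 1, 2*A}, {}, {} }]  -- n x A x B x C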
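For comparison, the same slicing idea in PyTorch might look like the sketch below (example sizes are arbitrary). Python indexing is 0-based and slice ends are exclusive, so the halves are :A and A:. Tensor.narrow(dim, start, length) is an equivalent spelling.

```python
import torch

# Arbitrary example sizes, filled with increasing values so the halves are easy to tell apart.
n, A, B, C = 2, 4, 3, 3
tensor_a = torch.arange(n * 2 * A * B * C, dtype=torch.float32).reshape(n, 2 * A, B, C)

# 0-based, end-exclusive slices along dim 1: indices 0..A-1 and A..2A-1.
tensor_b = tensor_a[:, :A]
tensor_c = tensor_a[:, A:]

# Same result via narrow(dim, start, length).
tensor_b_n = tensor_a.narrow(1, 0, A)
tensor_c_n = tensor_a.narrow(1, A, A)

# Basic slicing returns views sharing memory with tensor_a; clone() makes an independent copy.
tensor_b_copy = tensor_a[:, :A].clone()
```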
Upvotes: 0