stillanoob

Reputation: 1379

Split a tensor in torch

Given an input tensor of size n x 2A x B x C, how can I split it along the second dimension into two tensors, each of size n x A x B x C? Here n is the batch size.

Upvotes: 2

Views: 5660

Answers (2)

antoleb

Reputation: 323

You can use torch.split:

torch.split(input_tensor, split_size_or_sections=A, dim=1)
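As a minimal sketch of how that call behaves in PyTorch (the concrete sizes below are made up for illustration, they are not from the question): torch.split returns a tuple of chunks, so for a dimension of length 2A and a chunk size of A it can be unpacked into two tensors.

```python
import torch

# Illustrative sizes only; any positive n, A, B, C would do.
n, A, B, C = 4, 3, 5, 6
input_tensor = torch.randn(n, 2 * A, B, C)

# Cut dim 1 (length 2A) into chunks of length A; the result is a tuple.
first, second = torch.split(input_tensor, split_size_or_sections=A, dim=1)

print(first.shape, second.shape)
```

torch.chunk(input_tensor, 2, dim=1) should give the same two halves when the goal is "two equal parts" rather than "parts of size A".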

Upvotes: 1

Manuel Lagunas

Reputation: 2751

I think you could do something like:

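-- Lua Torch; n, A, B and C are assumed to be defined already
tensor_a = torch.Tensor(n, 2*A, B, C)
-- Initialize tensor_a with the data

-- Index ranges are 1-based and inclusive; the results are views, so no preallocation is needed
tensor_b = tensor_a[{{}, {1, A}, {}, {}}]
tensor_c = tensor_a[{{}, {A+1, 2*A}, {}, {}}]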
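For PyTorch rather than Lua Torch, the same index-range idea can be sketched with ordinary slice syntax. The sizes here are hypothetical, and note that Python indexing is 0-based with an exclusive upper bound.

```python
import torch

# Hypothetical sizes for illustration.
n, A, B, C = 2, 3, 4, 5
tensor_a = torch.zeros(n, 2 * A, B, C)

# Slice the second dimension: indices 0..A-1 and A..2A-1.
tensor_b = tensor_a[:, :A]
tensor_c = tensor_a[:, A:]

# Basic slicing returns views, so writes to tensor_a show up in both halves.
tensor_a.fill_(1.0)
print(tensor_b.shape, tensor_c.shape)
```

If independent copies are needed, calling .clone() on each slice would be one way to get them.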

Upvotes: 0
