a stand-out flamingo

Reputation: 61

Confusion about Pytorch `torch.split` documentation

When I read the documentation for torch.split in PyTorch, I find it hard to understand as a non-native English speaker:

torch.split(tensor, split_size_or_sections, dim=0)

[...]

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

Does "with sizes in dim" mean "with sizes in split_size_or_sections along the dimension dim"?

Upvotes: 1

Views: 1913

Answers (1)

iacob

Reputation: 24201

Don't worry: your English is fine. That line is a bit confusing.

Yes, you're correct. If you pass a list, e.g. split_size_or_sections=[1,2,4,5], the tensor is split into len([1,2,4,5]) = 4 chunks along dim, and the chunks have sizes 1, 2, 4 and 5 along dim, respectively. All other dimensions are left unchanged.
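A small example may make this concrete. The sketch below assumes a 12 x 3 tensor (chosen so that 1 + 2 + 4 + 5 = 12 matches the size of dim 0); the variable names are only for illustration. It splits the tensor with the list from above and then does a second split along dim=1, where the list has to sum to 3 instead.

```python
import torch

# A 12 x 3 tensor: 12 == 1 + 2 + 4 + 5, so the list below fits dim=0 exactly.
x = torch.arange(36).reshape(12, 3)

# One chunk per list entry; entry i is the chunk's size along dim 0.
chunks = torch.split(x, [1, 2, 4, 5], dim=0)

print(len(chunks))                       # 4
print([c.shape[0] for c in chunks])      # [1, 2, 4, 5]
print([tuple(c.shape) for c in chunks])  # [(1, 3), (2, 3), (4, 3), (5, 3)]

# Concatenating the chunks along the same dim gives back the original tensor.
print(torch.equal(torch.cat(chunks, dim=0), x))  # True

# Splitting along dim=1 instead: now the list must sum to x.size(1) == 3.
cols = torch.split(x, [1, 2], dim=1)
print([tuple(c.shape) for c in cols])    # [(12, 1), (12, 2)]
```

Note that "sizes in dim" only refers to the split dimension: every chunk keeps the full size of the tensor in all other dimensions.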

This implicitly assumes that sum([1,2,4,5]) (here 12) equals the size of the tensor along dim; if it doesn't, torch.split raises an error.
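To see the error case, here is a minimal sketch using a tensor whose size along dim 0 is 10, so the list [1, 2, 4, 5] (which sums to 12) does not fit. In the PyTorch versions I'm familiar with, the mismatch surfaces as a RuntimeError; treat the exact exception type and message as version-dependent. The flag name mismatch_raised is just for illustration.

```python
import torch

x = torch.zeros(10, 3)  # size along dim 0 is 10

# 1 + 2 + 4 + 5 == 12, which does not equal x.size(0) == 10, so this should fail.
try:
    torch.split(x, [1, 2, 4, 5], dim=0)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# A list that does sum to 10 works as expected.
ok = torch.split(x, [1, 2, 3, 4], dim=0)
print([c.shape[0] for c in ok])  # [1, 2, 3, 4]
```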

Upvotes: 1
