Reputation: 61
When I read the documentation for the function torch.split in PyTorch, I find it difficult to understand as a non-native English speaker:
torch.split(tensor, split_size_or_sections, dim=0)
[...]
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
Does "with sizes in dim" mean "with sizes in split_size_or_sections along the dimension dim"?
Upvotes: 1
Views: 1913
Reputation: 24201
Don't worry, your English is fine; that line is a bit confusing.
Yes, you're correct. It means that if you pass a list, e.g. split_size_or_sections=[1,2,4,5], it will split the tensor into len([1,2,4,5]) chunks (with the splits happening along dim), and the chunks will have lengths 1, 2, 4 and 5 along dim, respectively.
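A minimal sketch to confirm this reading, assuming a reasonably recent PyTorch install. The 2 x 12 tensor and the choice of dim=1 are illustrative only; the point is that the list sizes apply along dim while the other dimensions are left unchanged:

```python
import torch

# A 2 x 12 tensor; its size along dim=1 is 12, which equals sum([1, 2, 4, 5])
x = torch.arange(24).reshape(2, 12)

# Split along dim=1 into len([1, 2, 4, 5]) == 4 chunks
chunks = torch.split(x, [1, 2, 4, 5], dim=1)

for c in chunks:
    # Only the size along dim=1 follows the list; dim=0 stays 2
    print(c.shape)

# Concatenating the chunks back along the same dim recovers the original tensor
print(torch.equal(torch.cat(chunks, dim=1), x))
```

The printed shapes should be (2, 1), (2, 2), (2, 4) and (2, 5), i.e. "sizes in dim" are the list entries measured along the dimension dim.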
This assumes that sum([1,2,4,5]) equals the size of the tensor along dim; if it does not, an error is raised.
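A small sketch of the mismatch case, assuming a recent PyTorch version where the list form is handled by split_with_sizes and reports a RuntimeError; the exact error message wording may differ between versions:

```python
import torch

# Size along dim=0 is 10, but sum([1, 2, 4, 5]) == 12, so the sizes don't fit
x = torch.arange(10)

try:
    torch.split(x, [1, 2, 4, 5], dim=0)
    raised = False
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)
```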
Upvotes: 1