I'm trying to split a tensor into a list of tensors according to a group index for each element,
e.g.:
import torch
x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0, 0, 1, 1, 2])
res = func(x, indices)  # func is the function I'm looking for
I want the result to contain one tensor per group:
res = [[0.3018, -0.0079], [1.4995, -1.4422], [1.6007]]
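A minimal sketch for the example above, assuming that equal group ids are always adjacent in `indices` (as they are here, since `indices` is sorted). `torch.unique_consecutive(..., return_counts=True)` gives the size of each run, and `torch.split` cuts `x` into chunks of those sizes. The helper name `split_by_sorted_groups` is hypothetical, not a PyTorch function.

```python
import torch

def split_by_sorted_groups(x, indices):
    # Hypothetical helper. ASSUMES equal group ids are adjacent (e.g. sorted indices).
    # With return_counts=True, unique_consecutive returns (unique_ids, counts).
    _, counts = torch.unique_consecutive(indices, return_counts=True)
    # torch.split accepts a list of chunk sizes and returns a tuple of views into x.
    return list(torch.split(x, counts.tolist()))

x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0, 0, 1, 1, 2])
res = split_by_sorted_groups(x, indices)
print(res)
```

If the ids are not guaranteed to be contiguous, this sketch would silently produce wrong groups, so the input would have to be sorted by group id first.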
I'm wondering how I can achieve this. I checked gather and index_select, but I can't get a result like the one above.
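One likely reason gather and index_select don't fit: both return a single regular tensor, while the groups here have different sizes (2, 2 and 1), so the result has to be a Python list of tensors (or a padded tensor). A minimal sketch that also handles unsorted ids uses one boolean mask per unique group id; the helper name `group_by_index` is hypothetical, and it does one pass over `x` per group, which is fine for a small number of groups.

```python
import torch

def group_by_index(x, indices):
    # Hypothetical helper. Works for unsorted ids; torch.unique returns sorted ids,
    # so groups come out in ascending id order.
    # Boolean-mask indexing keeps the original element order within each group.
    return [x[indices == g] for g in torch.unique(indices)]

x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0, 0, 1, 1, 2])
res = group_by_index(x, indices)
print(res)

# Unsorted ids: group 0 collects elements 2.0 and 4.0, group 1 collects 1.0 and 3.0.
shuffled = group_by_index(torch.tensor([1.0, 2.0, 3.0, 4.0]), torch.tensor([1, 0, 1, 0]))
print(shuffled)
```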
Thank you!