whtitefall

Reputation: 671

Get a list of tensors from masked indices

I'm trying to get a list of tensors, one per group, where the group of each element is given by an index tensor,

e.g.,

import torch

x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])

indices = torch.tensor([0,0,1,1,2])

res = func(x, indices)

I want my result to be

res = [[0.3018, -0.0079], [1.4995, -1.4422], [1.6007]]

How can I achieve this? I checked gather and index_select, but I couldn't get a result like the one above.
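A note on why index_select alone falls short: it returns a single rectangular tensor, so it cannot produce groups of different lengths in one call. Calling it once per group does work. The sketch below assumes the x and indices from the example and uses torch.nonzero to turn each group's boolean mask into positions.

```python
import torch

x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0, 0, 1, 1, 2])

res = []
for g in indices.unique():
    # Positions where the group label equals g; nonzero gives shape (k, 1),
    # so squeeze to a 1-D LongTensor as index_select expects.
    pos = torch.nonzero(indices == g).squeeze(1)
    # One index_select call per group, so each group can have its own length.
    res.append(torch.index_select(x, 0, pos))

print([t.tolist() for t in res])
```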

Thank you!

Upvotes: 2

Views: 207

Answers (1)

Shai

Reputation: 114876

How about building a boolean mask for each unique group label:

res = [x[indices == i_] for i_ in indices.unique()]
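Here is a self-contained version of the one-liner above, wrapped in a function named func to match the question (the name and type hints are only illustrative). Groups come back in ascending label order because torch.unique sorts its output by default.

```python
from typing import List

import torch


def func(x: torch.Tensor, indices: torch.Tensor) -> List[torch.Tensor]:
    """Return one tensor per distinct label in indices, in ascending label order."""
    # Boolean-mask indexing keeps the original order of elements within a group.
    return [x[indices == i_] for i_ in indices.unique()]


x = torch.tensor([0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0, 0, 1, 1, 2])
res = func(x, indices)
print(res)
```

Because each group is selected with its own mask, the labels don't need to be contiguous. For example, indices [1, 0, 1] gives [x[1]] followed by [x[0], x[2]]. The trade-off is one full pass over x per distinct label.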

Upvotes: 3
