I looked for an answer on SO, but to be honest, I don't even know how to phrase the question.
Given a tensor, how do I select a set of indices along a specified axis?
```python
import torch

x = torch.randn(6, 7, 8, 9)  # example tensor
indices = torch.tensor([2, 4, 5])
# Depending on the context I need either
y = x[indices]
# Or
y = x[:, indices]
# Or any other axis
y = x[:, :, :, indices]
```
This is the function where I need it:
```python
import torch
from torch import Tensor

def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
    all_axes = set(range(x.ndim)) - {axis}
    y = x
    # Sum over all other axes
    for a in all_axes:
        y = y.sum(dim=a, keepdim=True)
    y = y.squeeze()
    # Get the sorted indices (ascending)
    _, idx = torch.sort(y)
    # Keep only a fraction of them
    idx = idx[:int(percentage * len(idx))]
    # Get the indices over some axis
    # !!! This part I'm not sure how to solve !!!
    # (get_over_axis is not a real Tensor method; it stands for the operation I'm looking for)
    return x.get_over_axis(axis=axis, indices=idx)
```
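The reduction loop in the question can also be done in a single call, because `torch.sum` accepts a tuple of dimensions for `dim`. A minimal sketch (the helper names `per_index_scores` and `per_index_scores_loop` are hypothetical, not from the post) that checks both produce the same per-index scores, assuming `x` has at least two dimensions:

```python
import torch


def per_index_scores(x: torch.Tensor, axis: int) -> torch.Tensor:
    """Sum x over every axis except `axis`: one score per index along `axis`.

    Assumes x.ndim >= 2, so the tuple of axes to reduce is never empty.
    """
    other_axes = tuple(a for a in range(x.ndim) if a != axis)
    return x.sum(dim=other_axes)


def per_index_scores_loop(x: torch.Tensor, axis: int) -> torch.Tensor:
    """The loop-based reduction from the question, for comparison."""
    y = x
    for a in range(x.ndim):
        if a != axis:
            y = y.sum(dim=a, keepdim=True)
    return y.squeeze()


x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
print(per_index_scores(x, 1).tolist())  # [60.0, 92.0, 124.0]
print(torch.allclose(per_index_scores(x, 1), per_index_scores_loop(x, 1)))  # True
```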
I think you are looking for `topk`:
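```python
import torch

def remove_weaklings(x, percentage, axis):
    k = int(percentage * x.shape[axis])
    # Returns a (values, indices) named tuple with the k largest entries along `axis`
    return torch.topk(x, k, dim=axis)
```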
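For reference, `torch.topk` returns a `(values, indices)` named tuple, and on a multi-dimensional tensor it picks the top `k` separately for every slice along `dim`, so different rows can end up with different indices. The small example below (values chosen arbitrarily) shows this and then, as an assumption about what the question is after, ranks a 1-D score first so that every row keeps the same columns:

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# Top 2 entries along dim=1, chosen independently for each row
values, indices = torch.topk(x, 2, dim=1)
print(values.tolist())   # [[5.0, 3.0], [6.0, 4.0]]
print(indices.tolist())  # [[1, 2], [2, 0]]  (rows get different columns)

# To keep the *same* columns for every row, rank a 1-D score instead
# (here the column sums) and then select those columns from x.
scores = x.sum(dim=0)            # [5.0, 7.0, 9.0]
_, cols = torch.topk(scores, 2)  # [2, 1]
print(x[:, cols].tolist())       # [[3.0, 5.0], [6.0, 2.0]]
```

Passing `largest=False` to `torch.topk` selects the smallest values instead, which matches the ascending `torch.sort` in the question.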
If you want a more generic solution, you can use NumPy-style slicing:
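```python
import numpy as np

def get_over_axis(x, indices, axis):
    # ':' (np.s_[:]) for every axis except `axis`, which gets `indices`
    i = []
    for d_ in range(x.dim()):
        if d_ == axis:
            i.append(indices)
        else:
            i.append(np.s_[:])
    return x[tuple(i)]
```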
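For a 1-D tensor of integer indices, PyTorch also has a built-in for this: `torch.index_select(input, dim, index)`. The sketch below builds the same index tuple as `get_over_axis` (note that `np.s_[:]` is just `slice(None)`, so NumPy isn't strictly needed) and checks that it matches `index_select` on every axis of an arbitrary example tensor:

```python
import torch

x = torch.arange(60).reshape(3, 4, 5)
indices = torch.tensor([0, 2])

for axis in range(x.dim()):
    # Same idea as get_over_axis: ':' everywhere except at `axis`
    index = [slice(None)] * x.dim()
    index[axis] = indices
    by_slicing = x[tuple(index)]
    # Built-in equivalent for a 1-D integer index tensor
    by_select = torch.index_select(x, axis, indices)
    print(axis, tuple(by_select.shape), torch.equal(by_slicing, by_select))
```

One difference: `index_select` requires a 1-D index, whereas the tuple-of-slices approach also accepts multi-dimensional index tensors.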