RafazZ

Reputation: 4588

How to specify the axis over which to get the elements in PyTorch

I was looking for an answer around SO, but tbh, I don't even know how to phrase the question.

Given a tensor, how do I select the elements at given indices along a specified axis?

Here is a simple example:

import torch

x = torch.randn(8, 8, 8, 8)  # example 4-D tensor
indices = torch.tensor([2, 4, 5])

# Depending on the context I need either
y = x[indices]

# Or
y = x[:, indices]

# Or any other axis
y = x[:, :, :, indices]
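For reference, one way to express all three of these forms with a single call is torch.index_select, which takes the dimension as an ordinary argument. This is not part of the original post; the tensor shape below is only an assumption for illustration.

```python
import torch

x = torch.arange(6 * 6 * 6 * 6).reshape(6, 6, 6, 6)
indices = torch.tensor([2, 4, 5])  # index_select needs a 1-D integer (long) tensor

# The axis is just a parameter here, so it can come from a variable
y0 = torch.index_select(x, 0, indices)  # same as x[indices]
y1 = torch.index_select(x, 1, indices)  # same as x[:, indices]
y3 = torch.index_select(x, 3, indices)  # same as x[:, :, :, indices]
```

The method form x.index_select(dim, indices) behaves the same way.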

Here is a use case where I need this behavior:


import torch
from torch import Tensor


def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
    all_axes = set(range(x.ndim)) - {axis}
    y = x
    # Sum up over all other axes
    for a in all_axes:
        y = y.sum(dim=a, keepdim=True)
    y = y.squeeze()
    # Get the sorted list of indices
    _, idx = torch.sort(y)
    # Keep only a fraction of the list
    idx = idx[:int(percentage * len(idx))]
    # Get the elements at these indices along `axis`
    # !!! Not sure how to solve this part (Tensor has no get_over_axis method) !!!
    return x.get_over_axis(axis=axis, indices=idx)
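A minimal sketch of how the placeholder could be filled in, assuming torch.index_select is acceptable (it is not used in the original post). It keeps the question's behavior of retaining the lowest-scoring fraction of positions along axis, and sums over the other axes in one call instead of a loop.

```python
import torch
from torch import Tensor


def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
    axis = axis % x.ndim  # accept negative axes such as -1
    other = tuple(d for d in range(x.ndim) if d != axis)
    # One score per position along `axis`: the sum over every other axis
    scores = x.sum(dim=other) if other else x
    # Positions along `axis`, ordered from lowest to highest score
    idx = torch.argsort(scores)
    # Keep only the requested fraction (same slicing as in the question)
    idx = idx[: int(percentage * idx.numel())]
    # index_select takes the axis as an argument, so no per-axis slicing is needed
    return x.index_select(axis, idx)
```

The kept positions come out in score order; if you want them in their original order, sort idx before the final index_select.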

Upvotes: 0

Views: 282

Answers (1)

Shai

Reputation: 114866

I think you are looking for topk:

import torch


def remove_weaklings(x, percentage, axis):
  k = int(percentage * x.shape[axis])
  # topk returns a (values, indices) named tuple, computed along dim=axis
  return torch.topk(x, k, dim=axis)
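A small illustration of what topk returns, using made-up numbers. Note that it ranks entries independently within each slice along dim; if the goal is to keep whole slices (as in the question, where slices are scored by summing over the other axes), one option, sketched here as an assumption rather than part of this answer, is to run topk on the per-slice scores and feed .indices to index_select.

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# Per-slice ranking: each row (dim=1) keeps its own 2 largest entries
values, indices = torch.topk(x, k=2, dim=1)
# values  -> tensor([[5., 3.], [6., 4.]])
# indices -> tensor([[1, 2], [2, 0]])

# Whole-column selection: score each column, then pick the k strongest columns
scores = x.sum(dim=0)                       # tensor([5., 7., 9.])
top_cols = torch.topk(scores, k=2).indices  # tensor([2, 1])
kept = x.index_select(1, top_cols)          # columns 2 and 1 of x
```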

If you want a more generic solution, you can build the index tuple yourself, using numpy's np.s_[:] (a full slice, like :) for every other axis:

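import numpy as np


def get_over_axis(x, indices, axis):
  i = []
  for d_ in range(x.dim()):
    if d_ == axis:
      i.append(indices)  # the index tensor goes on the requested axis
    else:
      i.append(np.s_[:])  # np.s_[:] is slice(None), i.e. ':'
  return x[tuple(i)]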
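A variant of the same index-tuple idea, sketched here without the NumPy import (np.s_[:] is just slice(None)). Two small changes are assumptions on top of the answer: negative axes are normalized, and trailing dimensions are left implicit, which indexing treats as ':'.

```python
import torch


def get_over_axis(x: torch.Tensor, indices: torch.Tensor, axis: int) -> torch.Tensor:
    axis = axis % x.dim()  # allow negative axes such as -1
    # Full slices for the leading dims, then the index tensor on `axis`;
    # any trailing dims are implicitly taken in full
    index = (slice(None),) * axis + (indices,)
    return x[index]
```

For example, get_over_axis(x, indices, 3) gives the same result as x[:, :, :, indices] on a 4-D tensor.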

Upvotes: 1
