Reputation: 245
I'm trying to index the maximum elements along the last dimension in a multidimensional tensor. For example, say I have a tensor
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
Here idx stores the maximum indices, which may look something like
>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
         [ 0.8627, 0.0685, 1.4241]],
        [[ 1.2924, 0.2456, 0.1764],
         [ 1.3777, 0.9401, 1.4637]],
        [[ 0.5235, 0.4550, 0.2476],
         [ 0.7823, 0.3004, 0.7792]],
        [[ 1.9384, 0.3291, 0.7914],
         [ 0.5211, 0.1320, 0.6330]],
        [[ 0.3292, 0.9086, 0.0078],
         [ 1.3612, 0.0610, 0.4023]]])
>>> idx
tensor([[2, 2],
        [0, 2],
        [0, 0],
        [0, 2],
        [1, 0]])
I want to be able to access these indices and use them to assign into another tensor. That is, I want to be able to do
B = torch.zeros(A.size())
B[idx] = A[idx]
where B is 0 everywhere except where A is maximum along the last dimension, i.e. B should store
>>> B
tensor([[[ 0, 0, 1.8663],
         [ 0, 0, 1.4241]],
        [[ 1.2924, 0, 0],
         [ 0, 0, 1.4637]],
        [[ 0.5235, 0, 0],
         [ 0.7823, 0, 0]],
        [[ 1.9384, 0, 0],
         [ 0, 0, 0.6330]],
        [[ 0, 0.9086, 0],
         [ 1.3612, 0, 0]]])
This is proving much more difficult than I expected, as idx does not index A the way I assumed it would. So far I have been unable to find a vectorized way of using idx to index A.
Is there a good vectorized way to do this?
Upvotes: 8
Views: 7012
Reputation: 1
You could use torch.scatter here. Note that this example takes the maximum along dim=1; a sketch for the question's last-dimension case follows the transcript.
>>> import torch
>>> a = torch.randn(4, 2, 3)
>>> a
tensor([[[ 0.1583,  0.1102, -0.8188],
         [ 0.6328, -1.9169, -0.5596]],
        [[ 0.5335,  0.4069,  0.8403],
         [-1.2537,  0.9868, -0.4947]],
        [[-1.2830,  0.4386, -0.0107],
         [ 1.3384,  0.5651,  0.2877]],
        [[-0.0334, -1.0619, -0.1144],
         [ 0.1954, -0.7371,  1.7001]]])
>>> val, ind = torch.max(a, 1, keepdim=True)
>>> ind
tensor([[[1, 0, 1]],
        [[0, 1, 0]],
        [[1, 1, 1]],
        [[1, 1, 1]]])
>>> torch.zeros_like(a).scatter(1, ind, val)
tensor([[[ 0.0000,  0.1102,  0.0000],
         [ 0.6328,  0.0000, -0.5596]],
        [[ 0.5335,  0.0000,  0.8403],
         [ 0.0000,  0.9868,  0.0000]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 1.3384,  0.5651,  0.2877]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 0.1954, -0.7371,  1.7001]]])
Upvotes: 0
Reputation: 36249
You can use torch.meshgrid to create an index tuple:
>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]
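On newer PyTorch versions, torch.meshgrid may emit a warning unless you pass the indexing keyword explicitly; the 'ij' behavior is the one needed here (this keyword may not exist on older releases):
>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]], indexing='ij') + (idx,)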
Note that you can also mimic meshgrid via (for the specific case of 3D):
>>> index_tuple = (
... torch.arange(A.size(0))[:, None],
... torch.arange(A.size(1))[None, :],
... idx
... )
A bit more explanation. We will have indices that look something like this:
In [173]: idx
Out[173]:
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])
From this, we want to go to three indices (since our tensor is 3D, we need three numbers to retrieve each element). Basically we want to build a grid in the first two dimensions, as shown below. (And that's why we use meshgrid).
In [174]: A[0, 0, 2], A[0, 1, 1]
Out[174]: (tensor(0.6288), tensor(-0.3070))
In [175]: A[1, 0, 2], A[1, 1, 0]
Out[175]: (tensor(1.7085), tensor(0.7818))
In [176]: A[2, 0, 2], A[2, 1, 1]
Out[176]: (tensor(0.4823), tensor(1.1199))
In [177]: A[3, 0, 2], A[3, 1, 2]
Out[177]: (tensor(1.6903), tensor(1.0800))
In [178]: A[4, 0, 2], A[4, 1, 2]
Out[178]: (tensor(0.9138), tensor(0.1779))
In the above five lines, the first two numbers in each index are basically the grid that we build using meshgrid, and the third number comes from idx.
i.e. the first two numbers form a grid.
(0, 0) (0, 1)
(1, 0) (1, 1)
(2, 0) (2, 1)
(3, 0) (3, 1)
(4, 0) (4, 1)
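Putting it together for the question's (5, 2, 3) example, a self-contained sketch of the arange-based variant (the names i0 and i1 are just illustrative), with a quick sanity check:
import torch

A = torch.randn(5, 2, 3)
_, idx = torch.max(A, dim=2)

# grid over the first two dimensions; idx supplies the index along the last one
i0 = torch.arange(A.size(0))[:, None]   # shape (5, 1)
i1 = torch.arange(A.size(1))[None, :]   # shape (1, 2)

B = torch.zeros_like(A)
B[i0, i1, idx] = A[i0, i1, idx]

# B keeps only the maximum of each row along the last dimension,
# so summing over that dimension recovers the row maxima
assert torch.equal(B.sum(dim=2), A.max(dim=2).values)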
Upvotes: 5
Reputation: 13103
An ugly workaround is to create a binary mask out of idx and use it to index the arrays. The basic code looks like this:
import torch
torch.manual_seed(0)
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)
The trick is that torch.arange(A.size(2)) enumerates the possible values of idx, and mask is nonzero in the places where they equal idx.
Remarks:
- If you only need the indices and not the maximum values from torch.max, you can use torch.argmax instead.
- The same result could presumably also be achieved with torch.nn.functional.max_pool3d with a kernel of size (1, 1, 3).
- The masking can also be done with torch.where, as shown below.
I would expect that somebody comes up with a cleaner solution (avoiding the intermediate allocation of the mask array), likely making use of torch.index_select, but I can't get it to work right now.
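For reference, a minimal sketch of the torch.where variant mentioned in the remarks, reusing the mask from the code block above:
B = torch.where(mask, A, torch.zeros_like(A))   # take A where mask is True, 0 elsewhere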
Upvotes: 3