Reputation: 11
Say I have an array of values w = [w1, w2, w3, ..., wn], sorted in ascending order, with all values equally spaced.
I have a PyTorch tensor of arbitrary shape. For the sake of this example, let's say that tensor is:
import torch
a = torch.rand(2,4)
Assuming w1 = torch.min(a) and wn = torch.max(a), I want to create two separate tensors, amax and amin, both of shape (2,4), such that amax contains, for each element of a, the nearest value from w that lies above it, and amin contains the nearest value from w that lies below it.
As an example, say:
a = tensor([[0.7192, 0.6264, 0.5180, 0.8836],
[0.1067, 0.1216, 0.6250, 0.7356]])
w = [0.0, 0.33, 0.66, 1]
Therefore, I would like amax and amin to be:
amax = tensor([[1.000, 0.66, 0.66, 1.000],
[0.33, 0.33, 0.66, 1.00]])
amin = tensor([[0.66, 0.33, 0.33, 0.66],
[0.00, 0.00, 0.33, 0.66]])
What is the fastest way to do this?
Upvotes: 0
Views: 1154
Reputation: 40658
You could compute, for every point in a, the difference with each bin in w. For this, you need a little bit of broadcasting:
>>> w = torch.tensor([0.0, 0.33, 0.66, 1.0])  # w must be a tensor, not a Python list
>>> z = a[..., None] - w[None, None]
>>> z
tensor([[[ 0.7192, 0.3892, 0.0592, -0.2808],
[ 0.6264, 0.2964, -0.0336, -0.3736],
[ 0.5180, 0.1880, -0.1420, -0.4820],
[ 0.8836, 0.5536, 0.2236, -0.1164]],
[[ 0.1067, -0.2233, -0.5533, -0.8933],
[ 0.1216, -0.2084, -0.5384, -0.8784],
[ 0.6250, 0.2950, -0.0350, -0.3750],
[ 0.7356, 0.4056, 0.0756, -0.2644]]])
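The broadcasting step can be sketched on its own to make the shapes explicit. This is a minimal illustration using the example values from the question; it assumes w has been converted to a tensor, and it relies on the fact that a 1-D w already broadcasts against the new trailing axis, so the same line works for an a of any shape.

```python
import torch

a = torch.tensor([[0.7192, 0.6264, 0.5180, 0.8836],
                  [0.1067, 0.1216, 0.6250, 0.7356]])
w = torch.tensor([0.0, 0.33, 0.66, 1.0])  # bins as a tensor, not a list

# a[..., None] has shape (2, 4, 1); w (shape (4,)) broadcasts along the
# new last axis, so z[i, j, k] == a[i, j] - w[k]
z = a[..., None] - w
print(z.shape)  # torch.Size([2, 4, 4])
```

Each row of the printed tensor above is therefore one element of a compared against all four bins.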
We have to identify, for each point in a, the index at which the sign change occurs (visually, the column here). We can apply the sign operator, compute the difference between consecutive columns of the result (entry i+1 minus entry i) with diff, retrieve the positions of the nonzero values with nonzero, then finally select the last coordinate and reshape the resulting tensor:
>>> index = z.sign().diff(dim=-1).nonzero()[:,2].view(2,4)
tensor([[2, 1, 1, 2],
[0, 0, 1, 2]])
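To see why this chain picks out the lower bin, here is a small sketch on a single element, using the example bins. It also shows one caveat that appears to follow from the approach: when a value coincides exactly with an interior bin, sign yields a 0 and two jumps are detected instead of one, so the final reshape would no longer match the shape of a.

```python
import torch

w = torch.tensor([0.0, 0.33, 0.66, 1.0])
x = torch.tensor(0.7192)           # a single element of a

signs = (x - w).sign()             # tensor([ 1.,  1.,  1., -1.])
jumps = signs.diff()               # tensor([ 0.,  0., -2.])
lower = jumps.nonzero()[:, 0]      # tensor([2]), i.e. w[2] <= x < w[3]

# caveat: a value sitting exactly on an interior bin produces two jumps
on_bin = (torch.tensor(0.33) - w).sign().diff().nonzero()
print(len(on_bin))                 # 2
```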
To get amin, simply index w with index:
>>> w[index]
tensor([[0.6600, 0.3300, 0.3300, 0.6600],
[0.0000, 0.0000, 0.3300, 0.6600]])
And to get amax, we can offset the indices to jump to the upper bound:
>>> w[index+1]
tensor([[1.0000, 0.6600, 0.6600, 1.0000],
[0.3300, 0.3300, 0.6600, 1.0000]])
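Putting the steps together, the approach can be wrapped in a function that does not hard-code the (2,4) shape. The sketch below is only one way to write it: the function names are hypothetical, and the second function is not part of the answer but a swapped-in alternative based on torch.bucketize, which performs a binary search over the sorted bins and avoids building the full difference tensor. Both assume w is a sorted 1-D tensor that spans the range of a.

```python
import torch

def bounds_by_sign_change(a, w):
    # the approach described above, written for an `a` of any shape;
    # assumes no element of `a` equals an interior bin exactly
    z = a[..., None] - w                                  # (*a.shape, len(w))
    index = z.sign().diff(dim=-1).nonzero()[:, -1].view(a.shape)
    return w[index], w[index + 1]                         # (amin, amax)

def bounds_by_bucketize(a, w):
    # alternative (an assumption, not from the answer): binary search;
    # bucketize returns i with w[i-1] < a <= w[i], clamped to a valid pair
    upper = torch.bucketize(a, w).clamp(min=1, max=len(w) - 1)
    return w[upper - 1], w[upper]                         # (amin, amax)

a = torch.tensor([[0.7192, 0.6264, 0.5180, 0.8836],
                  [0.1067, 0.1216, 0.6250, 0.7356]])
w = torch.tensor([0.0, 0.33, 0.66, 1.0])

amin, amax = bounds_by_sign_change(a, w)
amin2, amax2 = bounds_by_bucketize(a, w)
```

On the example data both functions return the amin and amax shown above. Which one is faster will depend on the sizes of a and w, so it is worth timing both on realistic inputs.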
Upvotes: 1