a_guest

Reputation: 36299

In NumPy, how do I select elements based on the maximum of their absolute values?

Basically I'd like to do the equivalent of the following Python for NumPy arrays of arbitrary dimension, along an arbitrary axis:

max(array, key=abs)

i.e. select, along a given axis, the element with the largest absolute value (similar to how array.max(axis=axis) selects the maximum value along that axis).

So for example (absmax is the desired function):

array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
absmax(array, axis=0)  # [-7,  8,  2]
absmax(array, axis=1)  # [ 8, -7, -4]
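For reference, the desired semantics can be written literally by applying Python's `max(..., key=abs)` to every 1-D slice with `np.apply_along_axis`. This is only a sketch: it makes one Python call per slice, so it is slow, and `absmax_reference` is a hypothetical name. It is still handy as a ground truth when checking faster versions.

```python
import numpy as np

def absmax_reference(a, axis):
    # Hypothetical helper: run Python's max(..., key=abs) on every 1-D slice along `axis`.
    # Slow, but a direct translation of the desired behaviour.
    return np.apply_along_axis(lambda v: max(v, key=abs), axis, a)

array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
print(absmax_reference(array, axis=0))  # [-7  8  2]
print(absmax_reference(array, axis=1))  # [ 8 -7 -4]
```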

I came up with the following implementation but it feels pretty clunky:

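import numpy as np

def absmax(a, *, axis):
    dims = list(a.shape)
    dims.pop(axis)
    # Open-mesh index arrays for every axis except `axis`
    # (wrapped in list() because recent NumPy versions return a tuple here).
    indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
    argmax = np.abs(a).argmax(axis=axis)
    # Put the argmax indices at the (normalized) position of `axis`.
    indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
    return a[tuple(indices)]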
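To see what the index tuple built in this function amounts to, here is the 2-D, `axis=0` case written out by hand (a sketch for this one case only). The argmax supplies the row index, and `np.arange` over the remaining axis plays the role of the `ogrid` entry.

```python
import numpy as np

a = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])

# axis=0: for each column j, the row holding the largest |a[i, j]|.
rows = np.abs(a).argmax(axis=0)   # array([1, 0, 0])
# The ogrid part only has to enumerate the remaining axis (the columns).
cols = np.arange(a.shape[1])      # array([0, 1, 2])
# Pairing row and column indices picks one element per column.
print(a[rows, cols])  # [-7  8  2]
```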

So I'm wondering whether there's a better / more concise way of achieving this.
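A different technique, not taken from either answer below, avoids indexing altogether: the element with the largest absolute value along an axis is always either the maximum or the minimum along that axis. The sketch below (`absmax_via_extremes` is a hypothetical name) assumes signed integer or float input, since `-lo` wraps around for unsigned dtypes, and it resolves an exact tie such as `3` vs `-3` in favour of the positive value rather than the first occurrence.

```python
import numpy as np

def absmax_via_extremes(a, axis=None):
    # Only two reductions are needed: the largest-|x| element is the max or the min.
    hi = np.max(a, axis=axis)
    lo = np.min(a, axis=axis)
    # hi >= -lo means |hi| >= |lo|; ties therefore pick the positive value.
    # Assumes signed ints or floats (-lo wraps around for unsigned dtypes).
    return np.where(hi >= -lo, hi, lo)

array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
print(absmax_via_extremes(array, axis=0))  # [-7  8  2]
print(absmax_via_extremes(array, axis=1))  # [ 8 -7 -4]
```

Because `np.max` and `np.min` also accept `axis=None` or a tuple of axes, this variant extends to those cases without extra code.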

Upvotes: 0

Views: 1365

Answers (2)

norok2

Reputation: 26896

Perhaps a simpler approach is to use np.take_along_axis() to implement a lambda_max() function that accepts a key parameter:

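import numpy as np

def lambda_max(arr, axis=None, key=None, keepdims=False):
    if callable(key):
        # Position of the largest key(x) along `axis` (a flat index when axis is None).
        idxs = np.argmax(key(arr), axis)
        if axis is not None:
            # take_along_axis needs an index array with the same ndim as arr.
            idxs = np.expand_dims(idxs, axis)
            result = np.take_along_axis(arr, idxs, axis)
            if not keepdims:
                result = np.squeeze(result, axis=axis)
            return result
        else:
            result = arr.ravel()[idxs]
            return np.reshape(result, (1,) * arr.ndim) if keepdims else result
    else:
        return np.amax(arr, axis, keepdims=keepdims)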

This can be used as follows:

print(lambda_max(array, 0, np.abs))
# [-7  8  2]
print(lambda_max(array, 1, np.abs))
# [ 8 -7 -4]
print(lambda_max(array, None, np.abs))
# 8
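The `axis is not None` branch of `lambda_max()` can be traced step by step on the example array with `axis=0` and `np.abs` as the key, as in the calls above. This sketch just prints the intermediate shapes; the variable names mirror the function's locals.

```python
import numpy as np

arr = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
axis = 0

idxs = np.argmax(np.abs(arr), axis)           # shape (3,): [1, 0, 0]
idxs = np.expand_dims(idxs, axis)             # shape (1, 3): same ndim as arr
picked = np.take_along_axis(arr, idxs, axis)  # shape (1, 3): [[-7, 8, 2]]
result = np.squeeze(picked, axis=axis)        # shape (3,): [-7, 8, 2]
print(idxs.shape, picked.shape, result)       # (1, 3) (1, 3) [-7  8  2]
```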

Upvotes: 1

Divakar

Reputation: 221624

For compactness, here's one that keeps the reduced dimension:

def absmax(a, axis):
    s = np.array(a.shape)
    s[axis] = -1
    return np.take_along_axis(a,np.abs(a).argmax(axis).reshape(s),axis=axis)

Sample runs:

In [67]: a
Out[67]: 
array([[ 5,  8,  2],
       [-7,  3,  0],
       [-2, -4, -1]])

In [68]: absmax(a, axis=0)
Out[68]: array([[-7,  8,  2]])

In [69]: absmax(a, axis=1)
Out[69]: 
array([[ 8],
       [-7],
       [-4]])
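As a quick sanity check that this keepdims version generalizes beyond 2-D, the sketch below repeats the answer's `absmax` and compares it on a random 3-D integer array against a brute-force reference built from Python's `max(..., key=abs)`. The shape and seed are arbitrary choices. Both `argmax` and `max` pick the first occurrence on ties, so the results should agree for every axis, including a negative one.

```python
import numpy as np

def absmax(a, axis):
    # Same as the answer above: the reduced axis is kept with length 1.
    s = np.array(a.shape)
    s[axis] = -1
    return np.take_along_axis(a, np.abs(a).argmax(axis).reshape(s), axis=axis)

rng = np.random.default_rng(0)
b = rng.integers(-9, 10, size=(2, 3, 4))

for ax in (0, 1, 2, -1):
    out = absmax(b, ax)
    # Brute-force check with Python's max(key=abs) on every 1-D slice.
    ref = np.expand_dims(np.apply_along_axis(lambda v: max(v, key=abs), ax, b), ax)
    print(ax, out.shape, np.array_equal(out, ref))
```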

If the extra dimension is bothersome, add a reshape step to the output:

out = np.take_along_axis(a,np.abs(a).argmax(axis).reshape(s),axis=axis)
return out.reshape(np.delete(s,axis))

Sample runs on the same input array:

In [89]: absmax(a, axis=0)
Out[89]: array([-7,  8,  2])

In [90]: absmax(a, axis=1)
Out[90]: array([ 8, -7, -4])
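Here is the shape bookkeeping of that reshape step, traced for `axis=1`. `np.delete(s, axis)` removes the `-1` placeholder and leaves the target shape. `np.squeeze(out, axis=axis)` drops the same length-1 axis and is shown here as an equivalent alternative, not as part of the answer.

```python
import numpy as np

a = np.array([[5, 8, 2], [-7, 3, 0], [-2, -4, -1]])
axis = 1

s = np.array(a.shape)
s[axis] = -1                       # s is now array([ 3, -1])
out = np.take_along_axis(a, np.abs(a).argmax(axis).reshape(s), axis=axis)  # shape (3, 1)

# np.delete drops the entry at position `axis`, leaving the target shape (3,).
flat = out.reshape(np.delete(s, axis))
# np.squeeze with an explicit axis removes the same length-1 axis.
squeezed = np.squeeze(out, axis=axis)
print(flat, squeezed, np.array_equal(flat, squeezed))  # [ 8 -7 -4] [ 8 -7 -4] True
```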

Upvotes: 1
