Basically, I'd like a NumPy equivalent of the following Python, for arrays of arbitrary dimension and along an arbitrary axis:
max(array, key=abs)
i.e. along the given axis, select the element with the largest absolute value, keeping its sign (similar to how array.max(axis=axis) selects the largest value along that axis).
For example (absmax is the desired function):
array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
absmax(array, axis=0)  # [-7, 8, 2]
absmax(array, axis=1)  # [ 8, -7, -4]
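For reference, the intended semantics can be spelled out directly with np.apply_along_axis, running the built-in max(..., key=abs) on every 1-D slice along the axis. This is a sketch for checking results, not a fast implementation (the key function runs in Python for every element), and absmax_reference is just an illustrative name. Both Python's max and NumPy's argmax return the first of several tied elements, so on ties such as 8 and -8 it agrees with argmax-based versions.

```python
import numpy as np

def absmax_reference(a, axis):
    # Hypothetical helper: apply Python's max(..., key=abs) to each 1-D slice along `axis`.
    # Slow, but it states the desired behaviour unambiguously.
    return np.apply_along_axis(lambda v: max(v, key=abs), axis, a)

array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
print(absmax_reference(array, axis=0))  # [-7  8  2]
print(absmax_reference(array, axis=1))  # [ 8 -7 -4]
```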
I came up with the following implementation, but it feels pretty clunky:
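import numpy as np

def absmax(a, *, axis):
    dims = list(a.shape)
    dims.pop(axis)  # shape of the result: the reduced axis removed
    # Open index grids for every remaining axis; list() so .insert() works even if np.ogrid returns a tuple
    indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
    argmax = np.abs(a).argmax(axis=axis)
    # Put the argmax indices where the reduced axis was (the modulo normalizes a negative axis)
    indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
    return a[tuple(indices)]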
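To see what the ogrid/insert code builds, here is the 2-D, axis=0 case written out by hand (a sketch on the example array, not a general implementation). After dims.pop(0) only the column axis remains, so np.ogrid yields just the column indices, and argmax supplies one row index per column; fancy indexing then pairs them element-wise.

```python
import numpy as np

a = np.array([[ 5,  8,  2],
              [-7,  3,  0],
              [-2, -4, -1]])

# Reducing over axis 0 (rows): argmax gives the row of the largest |value| in each column.
rows = np.abs(a).argmax(axis=0)   # [1, 0, 0]
# With a single slice, np.ogrid is just the column indices 0, 1, 2.
cols = np.ogrid[0:3]              # [0, 1, 2]
# Pairing them element-wise picks a[1, 0], a[0, 1], a[0, 2].
result = a[rows, cols]
print(result)                     # [-7  8  2]
```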
So I'm wondering: is there a better or more concise way to achieve this?
Perhaps a simpler approach is to use np.take_along_axis() to implement a lambda_max() function that accepts a key parameter:
def lambda_max(arr, axis=None, key=None, keepdims=False):
    if callable(key):
        # Positions of the largest key(arr) values along `axis` (a flat index if axis is None)
        idxs = np.argmax(key(arr), axis)
        if axis is not None:
            # take_along_axis needs an index array with the same ndim as arr
            idxs = np.expand_dims(idxs, axis)
            result = np.take_along_axis(arr, idxs, axis)
            if not keepdims:
                result = np.squeeze(result, axis=axis)
            return result
        else:
            return arr.ravel()[idxs]
    else:
        return np.amax(arr, axis, keepdims=keepdims)
This can be used as follows:
print(lambda_max(array, 0, np.abs))
# [-7 8 2]
print(lambda_max(array, 1, np.abs))
# [ 8 -7 -4]
print(lambda_max(array, None, np.abs))
# 8
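The core of lambda_max is the sequence argmax, expand_dims, take_along_axis, squeeze. np.take_along_axis requires the index array to have the same number of dimensions as arr, which is why the argmax result gets its reduced axis back (with length 1) before indexing and loses it again afterwards. Traced step by step for axis=1 on the example array (a sketch; names follow lambda_max where possible):

```python
import numpy as np

a = np.array([[ 5,  8,  2],
              [-7,  3,  0],
              [-2, -4, -1]])
axis = 1

idxs = np.argmax(np.abs(a), axis)            # shape (3,): column of the largest |value| in each row
idxs = np.expand_dims(idxs, axis)            # shape (3, 1): same ndim as `a`
picked = np.take_along_axis(a, idxs, axis)   # shape (3, 1)
result = np.squeeze(picked, axis=axis)       # shape (3,)
print(idxs.ravel(), picked.shape, result)    # [1 0 1] (3, 1) [ 8 -7 -4]
```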
In search of compactness, here's one that keeps the reduced dimension (as a length-1 axis):
def absmax(a, axis):
    s = np.array(a.shape)
    s[axis] = -1  # reshape infers this axis's length, which is 1 after argmax
    return np.take_along_axis(a, np.abs(a).argmax(axis).reshape(s), axis=axis)
Sample runs:
In [67]: a
Out[67]:
array([[ 5,  8,  2],
       [-7,  3,  0],
       [-2, -4, -1]])

In [68]: absmax(a, axis=0)
Out[68]: array([[-7,  8,  2]])

In [69]: absmax(a, axis=1)
Out[69]:
array([[ 8],
       [-7],
       [-4]])
If the extra dimension is bothersome, add a reshape step to the output:
    out = np.take_along_axis(a, np.abs(a).argmax(axis).reshape(s), axis=axis)
    return out.reshape(np.delete(s, axis))
Sample runs on the same input array:
In [89]: absmax(a, axis=0)
Out[89]: array([-7,  8,  2])

In [90]: absmax(a, axis=1)
Out[90]: array([ 8, -7, -4])
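On NumPy 1.22 or newer (an assumption about your environment), np.argmax itself accepts keepdims=True, which returns the index array already in the shape take_along_axis needs, so neither the reshape to s nor np.expand_dims is required. A sketch of that variant, checked on a random 3-D array for every axis, including negative ones, against the per-slice max(..., key=abs) definition from the question:

```python
import numpy as np

def absmax(a, axis, keepdims=False):
    # Assumes NumPy >= 1.22: argmax(..., keepdims=True) keeps the reduced axis with length 1,
    # which is exactly the index shape take_along_axis expects.
    idx = np.argmax(np.abs(a), axis=axis, keepdims=True)
    out = np.take_along_axis(a, idx, axis=axis)
    return out if keepdims else np.squeeze(out, axis=axis)

# Compare with the slow per-slice definition on a 3-D array, for every axis.
rng = np.random.default_rng(0)
b = rng.integers(-9, 10, size=(2, 3, 4))
for ax in range(-b.ndim, b.ndim):
    expected = np.apply_along_axis(lambda v: max(v, key=abs), ax, b)
    assert np.array_equal(absmax(b, ax), expected)
```

Here np.squeeze(out, axis=axis) does the same job as out.reshape(np.delete(s, axis)): it drops only the length-1 reduced axis.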