Cerin

Reputation: 64739

Finding row index of max column values in a numpy matrix

Is there a fast, efficient way to find, for each column of an NxM NumPy array, the row that holds the highest value?

I'm currently doing this via a nested loop in Python, which is relatively slow:

from PIL import Image
import numpy as np
img = Image.open('sample.jpg').convert('L')
width, height = size = img.size
y = np.asarray(img.getdata(), dtype=np.float64).reshape((height, width))
max_rows = [0]*width
for col_i in range(y.shape[1]):
    # Pair each value with its row index so max() returns the row of the largest value
    max_value, max_row = max([(y[row_i][col_i], row_i) for row_i in range(y.shape[0])])
    max_rows[col_i] = max_row

For a 640x480 image, this takes about 5 seconds. That is not huge in absolute terms, but more complex image operations, such as blurring, that are implemented entirely in NumPy/PIL/C take 0.01 seconds or less. I'm trying to perform this operation on a video stream, so it's a huge bottleneck. How do I speed this up, short of writing my own C extension?

Upvotes: 1

Views: 8099

Answers (1)

Suever

Reputation: 65430

Use numpy.argmax for this. It returns the index of the maximum value along a given axis, so with axis=0 it gives, for each column, the row index of that column's maximum.

row_index = np.argmax(y, axis=0)

# Alternatively, as a method on the array
row_index = y.argmax(axis=0)

For example:

data = np.random.rand(4,2)
# array([[ 0.09695379,  0.44602826],
#        [ 0.73614533,  0.19700072],
#        [ 0.87843682,  0.21188487],
#        [ 0.11389634,  0.51628872]])

row_index = data.argmax(axis=0)
# array([2, 3])
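
As a rough sanity check, the sketch below compares argmax against a Python 3 version of the loop from the question and also recovers the maximum values themselves. The array is random data standing in for the grayscale image, and the names rng, loop_rows and max_vals are made up for illustration. One caveat: when a column contains ties, argmax returns the first row holding the maximum, whereas the tuple-based max() in the question returns the last, so the two only agree when each column's maximum is unique.

```python
import numpy as np

# Hypothetical stand-in for the grayscale image (rows x columns);
# random floats make ties within a column practically impossible.
rng = np.random.default_rng(0)
y = rng.random((48, 64))

# Vectorized: row index of the maximum in each column.
max_rows = y.argmax(axis=0)

# Reference: the loop from the question, written for Python 3.
loop_rows = [
    max((y[row_i, col_i], row_i) for row_i in range(y.shape[0]))[1]
    for col_i in range(y.shape[1])
]

# The maximum values themselves, looked up via the row indices.
max_vals = y[max_rows, np.arange(y.shape[1])]

print(max_rows.tolist() == loop_rows)          # do both approaches agree?
print(np.array_equal(max_vals, y.max(axis=0))) # do the looked-up values match?
```

If only the values are needed, y.max(axis=0) gets them directly; the fancy-indexing lookup is mainly useful when the same row indices are applied to a second array of the same shape.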

Upvotes: 9
