postnubilaphoebus

Reputation: 136

More efficient way of looping over a multidimensional numpy array other than numpy.where

I have an array of shape [200, 500, 1000]. Each index represents a coordinate of an image: e.g. array[1, 2, 3] gives the value of the array at x=1, y=2, z=3 in coordinate space. The array contains repeating values in the range 1-20,000, and my goal is to find the x, y, and z coordinates of each of these values (each occurs multiple times).

Normally, I'd just loop over the values and ask at each step current_values = np.where(arr==current_index). However, that is slow. So I tried cupy instead (current_values = cp.where(arr==current_index)), but it sometimes fails at around the same index: cp.where yields an empty array at current_index==760 on one run and at current_index==780 on another. Keep in mind I am using the same data each time, so I am quite puzzled by this behaviour. The number of errors is small (it only happens about three times per loop), but it still impacts my accuracy.

I have two questions:

  1. Can I smartly restack my array and use a different function than np.where, so I don't have to use cupy?
  2. What is causing these errors with cupy.where?

Upvotes: 0

Views: 135

Answers (3)

Frank Yellin

Reputation: 11230

A completely different solution:

    import bisect

    import numpy as np

    array = np.random.randint(1, 20_001, (200, 500, 1000))
    flat_array = array.reshape(-1)  # a view; does not copy the array
    sorted_index = np.argsort(flat_array)

    # Suppose you want to find all elements that have value 100.
    # (The key= parameter of bisect requires Python 3.10+.)
    left = bisect.bisect_left(sorted_index, 100, key=lambda x: flat_array[x])
    right = bisect.bisect_right(sorted_index, 100, key=lambda x: flat_array[x])

At this point, flat_array[sorted_index[i]] == 100 for all values of i such that left <= i < right.

Each sorted_index[i] can be converted into a 3-dimensional index using:

np.unravel_index(sorted_index[i], array.shape)

It follows that

array[np.unravel_index(sorted_index[i], array.shape)] == 100

for that same range of indices.

Note: By default np.argsort uses an unstable sorting algorithm. If it's important that the indices come out in increasing order, add kind='stable' to the np.argsort call.


Changed to use np.unravel_index as suggested by @hpaulj. Fixed typo where I used temp instead of sorted_index.

Upvotes: 3

Carlos Horn

Reputation: 1293

You could use numba to collect the indices in one pass.

See for example the following code:

import numpy as np
from numba import njit
from numba.typed import List

@njit
def indices(data):
    nk, nj, ni = data.shape
    out = {}
    _data = data.ravel()
    for i in range(data.size):
        value = _data[i]
        if value not in out:
            out[value] = List.empty_list(np.int64)
        out[value].append(np.int64(i))
    return out

data = np.random.randint(0, 2000, size=(20,50,100), dtype=np.uint16)
res = {
    key: np.unravel_index(np.array(value), data.shape) 
    for key, value in indices(data).items()
}
print(res)

Note that I used the flat index inside numba and unravel it to the original shape afterwards, to keep the typing simple in numba's nopython mode.
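If numba isn't available, a similar single-pass grouping can be sketched in pure NumPy with one stable argsort (this variant is an illustration, not part of the answer above; the variable names are placeholders):

```python
import numpy as np

data = np.random.randint(0, 2000, size=(20, 50, 100), dtype=np.uint16)

flat = data.ravel()
order = np.argsort(flat, kind='stable')      # one O(n log n) sort
sorted_vals = flat[order]

# Positions where a new value starts in the sorted order.
boundaries = np.flatnonzero(np.diff(sorted_vals)) + 1
groups = np.split(order, boundaries)         # one chunk per distinct value

res = {int(flat[g[0]]): np.unravel_index(g, data.shape) for g in groups}

# Sanity check: every stored coordinate really holds its key's value.
for value, coords in res.items():
    assert np.all(data[coords] == value)
```

This produces the same value-to-coordinates mapping as the numba version, trading the explicit loop for a single sort plus vectorised splitting.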

Upvotes: 0

Frank Yellin

Reputation: 11230

It's still pretty slow, but you are asking about 100 million indices.

You could do something like:

    import collections

    import numpy as np

    array = np.random.randint(1, 20_001, (200, 500, 1000))
    result = collections.defaultdict(list)
    it = np.nditer(array, flags=['multi_index'])
    for value in it:
        result[value.item()].append(it.multi_index)

result[value] will now contain the multi-index of every element equal to value.
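At a smaller, illustrative scale the loop can be cross-checked against np.where (the shape and value range below are placeholders):

```python
import collections

import numpy as np

array = np.random.randint(1, 21, (4, 5, 6))    # small demo array
result = collections.defaultdict(list)
it = np.nditer(array, flags=['multi_index'])
for value in it:
    result[value.item()].append(it.multi_index)

# Pick a value that certainly occurs and compare with np.where.
target = array.flat[0].item()
expected = {tuple(int(i) for i in t) for t in zip(*np.where(array == target))}
got = {tuple(int(i) for i in idx) for idx in result[target]}
assert got == expected
```

Both approaches find the same coordinates; nditer just collects them for every value in one pass instead of scanning the array once per value.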

Upvotes: 0
