mrapacz

Reputation: 1019

Split a 2D array of values into X and y datasets

I have the following np.ndarray:

>>> arr
array([[1, 2],
       [3, 4]])

I would like to split it into X and y, where X holds the coordinates of each element and y holds the corresponding values. So far I have managed to do this using np.ndenumerate:

>>> X, y = zip(*np.ndenumerate(arr))
>>> X 
((0, 0), (0, 1), (1, 0), (1, 1))
>>> y
(1, 2, 3, 4)

I'm wondering if there's a more idiomatic and faster way to achieve it, since the arrays I'm actually dealing with have millions of values.

I need the X and y arrays to pass to a sklearn classifier later. The formats above seemed the most natural to me, but perhaps there's a better format for passing them to the fit function.

Upvotes: 1

Views: 1555

Answers (2)

Andrzej Pisarek

Reputation: 271

Reshaping arr into y is easy: y = arr.flatten() does it. I suggest treating the generation of X as a separate task.

Let's assume that your dataset is of shape NxM. In our benchmark we set N to 500 and M to 1000.

N = 500
M = 1000
arr = np.random.randn(N, M)

Then, using np.mgrid and rearranging the index axes, you can build the coordinate array as:

np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)

Benchmarks:

%timeit np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)
# 3.11 ms ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit zip(*np.ndenumerate(arr))
# 235 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
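The %timeit lines above are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard timeit module. The function names with_mgrid and with_ndenumerate and the smaller array size are assumptions made for illustration, so the absolute numbers will not match the ones quoted above.

```python
import timeit
import numpy as np

# Smaller than the 500 x 1000 benchmark above so the script finishes quickly.
N, M = 50, 100
arr = np.random.randn(N, M)

def with_mgrid():
    # Coordinate array of shape (N * M, 2), built from index grids.
    return np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)

def with_ndenumerate():
    # tuple() forces the zip so the whole iteration is actually timed.
    return tuple(zip(*np.ndenumerate(arr)))

# Best of 3 repeats, 10 calls each, reported per call.
t_mgrid = min(timeit.repeat(with_mgrid, number=10, repeat=3)) / 10
t_enum = min(timeit.repeat(with_ndenumerate, number=10, repeat=3)) / 10
print(f"mgrid:       {t_mgrid * 1e3:.3f} ms per call")
print(f"ndenumerate: {t_enum * 1e3:.3f} ms per call")
```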

In your case you can get N and M by unpacking the shape:

N, M = arr.shape

and then:

X = np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)
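Putting the two pieces together, here is a minimal sketch on the 2x2 array from the question. The helper name split_coords_values is made up for this example. It uses ravel() rather than flatten() for y, which returns a view where possible instead of always copying. X comes out with shape (n_samples, 2) and y with shape (n_samples,), which should be the layout sklearn's fit expects.

```python
import numpy as np

def split_coords_values(arr):
    """Return (X, y): X has one (row, col) pair per element, y the matching values."""
    n_rows, n_cols = arr.shape
    # mgrid has shape (2, n_rows, n_cols): a plane of row indices and a plane of column indices.
    grid = np.mgrid[:n_rows, :n_cols]
    # Move the "which index" axis last, then flatten to (n_rows * n_cols, 2).
    X = grid.transpose(1, 2, 0).reshape(-1, 2)
    # ravel walks the array in the same row-major order as the reshape above.
    y = arr.ravel()
    return X, y

arr = np.array([[1, 2], [3, 4]])
X, y = split_coords_values(arr)
print(X.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]
print(y.tolist())  # [1, 2, 3, 4]
```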

Upvotes: 1

Chris

Reputation: 29742

Use numpy.where with numpy.ravel():

import numpy as np

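def ndenumerate(np_array):
    # np.where returns the indices of nonzero entries only, so np_array+1
    # covers every element only if no value in the array equals -1.
    return list(zip(*np.where(np_array+1))), np_array.ravel()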

arr = np.random.randint(0, 100, (1000,1000))

X_new, y_new = ndenumerate(arr)
X,y = zip(*np.ndenumerate(arr))

Output (validation):

all(i1 == i2 for i1, i2 in zip(X, X_new))
# True
all(y == y_new)
# True

Benchmark (roughly 3.7x faster on this run):

%timeit ndenumerate(arr)
# 234 ms ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit zip(*np.ndenumerate(arr))
# 877 ms ± 91.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
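One thing to watch with the np.where approach: it only returns the indices of nonzero entries, so np_array+1 silently drops any element equal to -1. A value-independent variant can be sketched with np.indices, which looks only at the shape. The helper name coords_and_values is an assumption for illustration, and this variant has not been benchmarked against the timings above.

```python
import numpy as np

def coords_and_values(arr):
    # np.indices depends only on the shape, never on the stored values.
    # Shape (ndim, *arr.shape) -> (ndim, size) -> (size, ndim).
    X = np.indices(arr.shape).reshape(arr.ndim, -1).T
    return X, arr.ravel()

arr = np.array([[-1, 2], [3, 4]])

# np.where skips the element where arr + 1 == 0.
rows, cols = np.where(arr + 1)
print(len(rows))  # 3

X, y = coords_and_values(arr)
print(X.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]
print(y.tolist())  # [-1, 2, 3, 4]
```

This returns X as a 2D integer array rather than a list of tuples, which is usually the more convenient form for a later fit call.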

Upvotes: 1
