Reputation: 1019
I have the following np.ndarray:
>>> arr
array([[1, 2],
       [3, 4]])
I would like to split it into X and y, where X holds the coordinates of each element and y holds the corresponding values. So far I have managed to do this using np.ndenumerate:
>>> X, y = zip(*np.ndenumerate(arr))
>>> X
((0, 0), (0, 1), (1, 0), (1, 1))
>>> y
(1, 2, 3, 4)
I'm wondering if there's a more idiomatic and faster way to achieve it, since the arrays I'm actually dealing with have millions of values.
I need the X and y arrays so that I can pass them to a sklearn classifier later. The formats above seemed the most natural to me, but perhaps there's a better way to pass them to the fit function.
Upvotes: 1
Views: 1555
Reputation: 271
Reshaping arr to y is easy; you can achieve it with y = arr.flatten(). I suggest treating the generation of X as a separate task.
Let's assume that your dataset has shape N x M. In our benchmark we set N to 500 and M to 1000.
import numpy as np
N = 500
M = 1000
arr = np.random.randn(N, M)
Then, by using np.mgrid and rearranging the indices, you can get the result as:
np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)
Benchmarks:
%timeit np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)
# 3.11 ms ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit zip(*np.ndenumerate(arr))
# 235 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In your case you can get N and M by unpacking the shape:
N, M = arr.shape
and then:
X = np.mgrid[:N, :M].transpose(1, 2, 0).reshape(-1, 2)
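Putting the two steps together, here is a minimal sketch on the 2x2 array from the question. The helper name coords_and_values is made up for illustration; it relies on flatten() and the reshape both walking the array in row-major order, so row i of X lines up with element i of y.

```python
import numpy as np

def coords_and_values(arr):
    """Return (X, y) for a 2-D array: X holds (row, col) pairs, y the matching values."""
    n_rows, n_cols = arr.shape
    # mgrid gives shape (2, N, M); move the coordinate axis last, then flatten to (N*M, 2)
    X = np.mgrid[:n_rows, :n_cols].transpose(1, 2, 0).reshape(-1, 2)
    # flatten() visits elements in the same row-major order as the reshape above
    y = arr.flatten()
    return X, y

arr = np.array([[1, 2], [3, 4]])
X, y = coords_and_values(arr)
print(X.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]
print(y.tolist())  # [1, 2, 3, 4]
```

X comes out with shape (N*M, 2) and y with shape (N*M,), i.e. one row of coordinates per value.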
Upvotes: 1
Reputation: 29742
Use numpy.where with numpy.ravel():
import numpy as np
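def ndenumerate(np_array):
    # np.where returns the indices of nonzero entries, so np_array + 1 must be
    # nonzero everywhere: this assumes the array contains no -1 values.
    return list(zip(*np.where(np_array+1))), np_array.ravel()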
arr = np.random.randint(0, 100, (1000,1000))
X_new, y_new = ndenumerate(arr)
X,y = zip(*np.ndenumerate(arr))
Output (validation):
all(i1 == i2 for i1, i2 in zip(X, X_new))
# True
all(y == y_new)
# True
Benchmark (about 3x faster):
%timeit ndenumerate(arr)
# 234 ms ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit zip(*np.ndenumerate(arr))
# 877 ms ± 91.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
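One caveat: np.where(np_array + 1) only returns every index when no element equals -1, which holds for the randint(0, 100, ...) data above but not in general. A possible value-independent variant, sketched here with np.indices (the function name coords_and_values_nd is hypothetical, and it is not benchmarked against the timings above):

```python
import numpy as np

def coords_and_values_nd(arr):
    """Return (X, y) for an array of any dimension, regardless of its values."""
    # np.indices depends only on the shape, so entries such as -1 or 0 are not skipped
    X = np.indices(arr.shape).reshape(arr.ndim, -1).T
    # ravel() returns the values in the same row-major order (a view when possible)
    y = arr.ravel()
    return X, y

arr = np.array([[-1, 2], [3, 4]])
X, y = coords_and_values_nd(arr)

# For comparison: np.where(arr + 1) misses (0, 0) because -1 + 1 == 0
n_from_where = len(np.where(arr + 1)[0])
print(len(X), n_from_where)  # 4 3
```

This returns X as an array of shape (arr.size, arr.ndim) rather than a list of tuples.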
Upvotes: 1