adrin

Reputation: 4896

How to exclude rows/columns from numpy.ndarray data

Assume we have a numpy.ndarray, say with shape (100, 200), and a list of row indices that should be excluded from it. How would you do that? Something like this:

import numpy

a = numpy.random.rand(100,200)
indices = numpy.random.randint(100,size=20)
b = a[-indices,:] # imaginary code, what to replace here?

Thanks.

Upvotes: 20

Views: 31158

Answers (4)

Bang

Reputation: 1132

You can use b = numpy.delete(a, indices, axis=0)

Source: NumPy docs.
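As a minimal sketch of the call above (the seeded generator and the variable name `n_removed` are only illustrative assumptions), this shows that `numpy.delete` returns a new array and leaves `a` untouched, and that repeated values in `indices` remove a row only once:

```python
import numpy as np

rng = np.random.default_rng(0)           # seeded only to make the sketch repeatable
a = rng.random((100, 200))
indices = rng.integers(0, 100, size=20)  # may contain duplicate row numbers

# np.delete does not modify `a`; it returns a copy without the listed rows.
b = np.delete(a, indices, axis=0)

# Duplicates in `indices` count once, so compare against the unique count.
n_removed = np.unique(indices).size
print(a.shape, b.shape, n_removed)
```

Passing `axis=1` instead would drop columns in the same way.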

Upvotes: 18

Thomas Arildsen

Reputation: 1310

You could try:

import numpy as np

a = np.random.rand(100,200)
indices = np.random.randint(100,size=20)
# keep only the row numbers that are not listed in indices
b = a[np.setdiff1d(np.arange(100),indices),:]

This avoids creating a mask array of the same size as your data, as is done in https://stackoverflow.com/a/21022753/865169. Note that this example produces a 2D array b, whereas that answer produces a flattened array.

A crude comparison of the runtime and memory cost of this approach against https://stackoverflow.com/a/30273446/865169 suggests that delete is faster, while indexing with setdiff1d is much lighter on memory:

In [75]: %timeit b = np.delete(a, indices, axis=0)
The slowest run took 7.47 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 24.7 µs per loop

In [76]: %timeit c = a[np.setdiff1d(np.arange(100),indices),:]
10000 loops, best of 3: 48.4 µs per loop

In [77]: %memit b = np.delete(a, indices, axis=0)
peak memory: 52.27 MiB, increment: 0.85 MiB

In [78]: %memit c = a[np.setdiff1d(np.arange(100),indices),:]
peak memory: 52.39 MiB, increment: 0.12 MiB
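The two expressions timed above should select the same rows in the same order, since `np.setdiff1d` returns the kept row numbers sorted and `np.delete` preserves the original order. A small sanity check, assuming a seeded generator and the illustrative names `keep` and `same`:

```python
import numpy as np

rng = np.random.default_rng(1)           # seeded only to make the sketch repeatable
a = rng.random((100, 200))
indices = rng.integers(0, 100, size=20)

# Approach 1: delete the listed rows.
b = np.delete(a, indices, axis=0)

# Approach 2: index with the sorted set difference of row numbers.
keep = np.setdiff1d(np.arange(a.shape[0]), indices)
c = a[keep, :]

same = np.array_equal(b, c)
print(same)
```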

Upvotes: 6

MB-F

Reputation: 23647

You could try something like this:

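import numpy

a = numpy.random.rand(100,200)
indices = numpy.random.randint(100,size=20)
mask = numpy.ones(a.shape, dtype=bool)  # mask has the same shape as a
mask[indices,:] = False                 # mark the excluded rows
b = a[mask]                             # note: b is a flattened 1-D array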
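Because the mask above has the same shape as `a`, `a[mask]` returns a flattened 1-D array. A possible variant, sketched here as an assumption rather than as part of the original answer, is a 1-D boolean mask over the rows only, which keeps the result 2-D (the name `row_mask` is illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)           # seeded only to make the sketch repeatable
a = rng.random((100, 200))
indices = rng.integers(0, 100, size=20)

# One boolean per row instead of one per element.
row_mask = np.ones(a.shape[0], dtype=bool)
row_mask[indices] = False

# Boolean indexing along axis 0 keeps whole rows, so b stays 2-D.
b = a[row_mask]
print(b.shape)
```

The same idea applies to columns with a mask of length `a.shape[1]` used as `a[:, col_mask]`.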

Upvotes: 1

Andrey Shokhin

Reputation: 12192

It's ugly but works:

b = np.array([a[i] for i in range(a.shape[0]) if i not in indices])

Upvotes: 3
