SeF

Reputation: 4160

Fast way to check if a numpy array is binary (contains only 0 and 1)

Given a NumPy array, how can I quickly check whether it contains only 0 and 1? Is there a built-in method for this?

Upvotes: 15

Views: 10612

Answers (7)

Bram Schijvenaars

Reputation: 21

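The following should work. Flatten the array first with ravel(), since set() cannot hash the rows of a multi-dimensional array:

ans = set(arr.ravel()).issubset([0, 1])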

Upvotes: 2

Mai Hai

Reputation: 1330

We could use np.isin().

is_binary = np.isin(input_array, [0, 1]).all()

np.isin() checks, element by element, whether each value of the input is 0 or 1.
It returns a boolean array of the same shape as the input, e.g. [True, False, True, True, ...], so there is no need to squeeze or flatten a multi-dimensional array first.
Then all() ensures that every entry of that boolean array is True.
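To make the np.isin() check concrete, here is a minimal sketch applied to a 2-D array. The helper name is_binary_isin is hypothetical and only used for illustration.

```python
import numpy as np

def is_binary_isin(arr):
    """Return True if every element of arr is 0 or 1 (hypothetical helper)."""
    # np.isin is elementwise and returns a boolean array with arr's shape;
    # .all() then collapses it to a single truth value.
    return bool(np.isin(arr, [0, 1]).all())

grid = np.array([[0, 1], [1, 0]])
print(is_binary_isin(grid))      # only 0s and 1s
print(is_binary_isin(grid + 1))  # contains a 2
```

Wrapping the result in bool() just converts the NumPy boolean into a plain Python one.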

Upvotes: 3

ahmedhosny

Reputation: 1177

How about np.unique?

np.unique(arr)

This should return [0, 1] if the array is binary (or just [0] or [1] if it holds only one of the two values).
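One way to turn the np.unique() output into a yes/no answer is sketched below. The helper name is_binary_unique is hypothetical, and the sketch assumes that an all-zeros or all-ones array should also count as binary.

```python
import numpy as np

def is_binary_unique(arr):
    """Return True if the distinct values of arr are a subset of {0, 1}."""
    # np.unique flattens the input and returns its sorted distinct values.
    values = np.unique(arr)
    # Accept [0], [1] or [0, 1]: every distinct value must be 0 or 1.
    return bool(np.isin(values, [0, 1]).all())

print(is_binary_unique(np.array([[0, 1], [1, 1]])))
print(is_binary_unique(np.array([0, 1, 2])))
```

Note that np.unique sorts the data, so this is unlikely to be the fastest option on large arrays.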

Upvotes: 1

user7138814

Reputation: 2041

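With only a single loop over the data (axis=None reduces over all dimensions, so this also works for multi-dimensional arrays):

0 <= np.bitwise_or.reduce(ar, axis=None) <= 1

Note that this doesn't work for floating-point dtypes.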
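A minimal sketch of the reduction with an explicit dtype guard follows. The helper name is_binary_bitwise and the choice to raise TypeError for non-integer input are assumptions made for illustration.

```python
import numpy as np

def is_binary_bitwise(ar):
    """OR all elements together; the result is 0 or 1 only for binary input."""
    ar = np.asarray(ar)
    # bitwise_or is only defined for integer and boolean dtypes.
    if not (np.issubdtype(ar.dtype, np.integer) or ar.dtype == np.bool_):
        raise TypeError("bitwise_or needs an integer or boolean dtype")
    # axis=None reduces over every dimension down to one scalar.
    combined = np.bitwise_or.reduce(ar, axis=None)
    # Any set bit above the lowest one pushes the result past 1;
    # any negative value makes the result negative.
    return bool(0 <= combined <= 1)

print(is_binary_bitwise(np.array([[0, 1], [1, 1]])))
print(is_binary_bitwise(np.array([0, 2])))
```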

If the values are guaranteed to be non-negative, you can get short-circuiting behavior:

try:
    # Use ar as integer indices into a length-2 array;
    # any index >= 2 is out of bounds and raises IndexError.
    np.empty((2,), bool)[ar]
    is_binary = True
except IndexError:
    is_binary = False

This method always allocates a temporary array of the same shape as the argument and seems to loop over the data more slowly than the first method.

Upvotes: 4

Thomas Baruchel

Reputation: 7517

It looks like you can achieve it with something like:

np.array_equal(a, a.astype(bool))

If your array is large, this should create fewer temporary arrays than some of the other answers, so it is probably slightly faster than them (not tested, however).
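To see why the comparison works, here is a small sketch. The helper name is_binary_cast is hypothetical. Casting to bool maps 0 to False and everything else to True, and since False == 0 and True == 1, the cast array equals the original only when every value is 0 or 1.

```python
import numpy as np

def is_binary_cast(a):
    """Compare a with its boolean cast; they match only for 0s and 1s."""
    a = np.asarray(a)
    # astype(bool): 0 -> False, any other value -> True.
    # array_equal then compares elementwise, treating False as 0 and True as 1.
    return bool(np.array_equal(a, a.astype(bool)))

print(is_binary_cast(np.array([0.0, 1.0, 1.0])))  # floats are fine here
print(is_binary_cast(np.array([0, 2])))           # 2 becomes True, and 2 != 1
```

Unlike the bitwise reduction, this variant also handles floating-point input.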

Upvotes: 13

JoshAdel

Reputation: 68682

If you have access to Numba (or alternatively Cython), you can write something like the following. It will be significantly faster at catching non-binary arrays, since it short-circuits and stops at the first offending element instead of continuing through all of them:

import numpy as np
import numba as nb

@nb.njit
def check_binary(x):
    is_binary = True
    for v in np.nditer(x):
        if v.item() != 0 and v.item() != 1:
            is_binary = False
            break

    return is_binary

Running this in pure Python, without the aid of an accelerator like Numba or Cython, makes this approach prohibitively slow.
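If neither accelerator is available, a rough middle ground is to scan the array in chunks with plain NumPy and stop at the first chunk that fails. This is only a sketch and not part of the timings in this answer; the name check_binary_chunked and the default chunk_size are assumptions chosen for illustration.

```python
import numpy as np

def check_binary_chunked(x, chunk_size=65536):
    """Scan x in chunks and return False at the first non-binary chunk."""
    # reshape(-1) gives a 1-D view when possible, otherwise a copy.
    flat = np.asarray(x).reshape(-1)
    for start in range(0, flat.size, chunk_size):
        chunk = flat[start:start + chunk_size]
        # Vectorized test on this chunk only; later chunks are never
        # touched once a violation has been found.
        if not ((chunk == 0) | (chunk == 1)).all():
            return False
    return True

print(check_binary_chunked(np.random.randint(0, 2, (300, 300))))
print(check_binary_chunked(np.full((300, 300), 2)))
```

The early exit is per chunk rather than per element, so the benefit depends on where the first non-binary value sits.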

Timings:

a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s

%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 15.1 ms per loop

%timeit check_binary(a)
# 100 loops, best of 3: 11.6 ms per loop

a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well

%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 14.9 ms per loop

%timeit check_binary(a)
# 1000000 loops, best of 3: 543 ns per loop

Upvotes: 3

Divakar

Reputation: 221564

A few approaches:

((a==0) | (a==1)).all()
~((a!=0) & (a!=1)).any()
np.count_nonzero((a!=0) & (a!=1))==0
a.size == np.count_nonzero((a==0) | (a==1))

Runtime test:

In [313]: a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s

In [314]: %timeit ((a==0) | (a==1)).all()
     ...: %timeit ~((a!=0) & (a!=1)).any()
     ...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
     ...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
     ...: 
10 loops, best of 3: 28.8 ms per loop
10 loops, best of 3: 29.3 ms per loop
10 loops, best of 3: 28.9 ms per loop
10 loops, best of 3: 28.8 ms per loop

In [315]: a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well

In [316]: %timeit ((a==0) | (a==1)).all()
     ...: %timeit ~((a!=0) & (a!=1)).any()
     ...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
     ...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
     ...: 
10 loops, best of 3: 28 ms per loop
10 loops, best of 3: 27.5 ms per loop
10 loops, best of 3: 29.1 ms per loop
10 loops, best of 3: 28.9 ms per loop

Their runtimes seem to be comparable.

Upvotes: 15
