Reputation: 4160
Given a NumPy array, how can I quickly check whether it contains only 0 and 1? Is there an existing method for this?
Upvotes: 15
Views: 10612
Reputation: 1330
We could use np.isin().
is_binary = np.isin(input_array, [0, 1]).all()
np.isin() checks, element by element, whether each value of the input is 0 or 1. It returns a boolean array with the same shape as the input, e.g. [True, False, True, True, ...].
It handles multi-dimensional arrays directly, so no squeeze() or reshape is needed beforehand. (Calling squeeze(-1) would in fact raise a ValueError whenever the last axis does not have length 1.)
Then all() ensures that every element of that boolean array is True.
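As a quick illustration of how np.isin() behaves on a multi-dimensional input, here is a minimal sketch. The two sample arrays and the variable names are made up for the example.

```python
import numpy as np

# Hypothetical sample data: one 2-D array of 0s and 1s, one with a stray 2.
binary_2d = np.array([[0, 1, 1], [1, 0, 0]])
non_binary_2d = np.array([[0, 1, 2], [1, 0, 0]])

# np.isin returns a boolean array with the same shape as its first argument.
mask = np.isin(binary_2d, [0, 1])

binary_result = bool(mask.all())
non_binary_result = bool(np.isin(non_binary_2d, [0, 1]).all())
```

Here mask keeps the (2, 3) shape of the input, and .all() collapses it to a single truth value.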
Upvotes: 3
Reputation: 1177
How about numpy unique?
np.unique(arr)
Should return [0, 1] if binary (or just [0] or [1] if the array holds only one of the two values).
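To turn that into a single True/False answer, one possible sketch is to test whether the distinct values are a subset of {0, 1}. The helper name is_binary_unique is hypothetical, not part of NumPy.

```python
import numpy as np

def is_binary_unique(arr):
    """Return True if every value in arr is 0 or 1 (hypothetical helper)."""
    # np.unique flattens the input and returns its sorted distinct values.
    values = np.unique(arr)
    # Accept [0], [1] and [0, 1]; reject anything containing another value.
    return bool(np.isin(values, [0, 1]).all())

all_zeros_ok = is_binary_unique(np.zeros((3, 3), dtype=int))
mixed_ok = is_binary_unique(np.array([[0, 1], [1, 0]]))
has_two = is_binary_unique(np.array([[0, 1], [1, 2]]))
```

Note that np.unique sorts the data, so this is likely to do more work than the element-wise comparisons in the other answers.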
Upvotes: 1
Reputation: 2041
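With only a single pass over the data:
0 <= np.bitwise_or.reduce(ar, axis=None) <= 1
Passing axis=None reduces over all axes, so this also works for multi-dimensional arrays (the default, axis=0, would return an array rather than a scalar).
Note that this doesn't work for floating-point dtypes, because bitwise_or is only defined for integer and boolean arrays.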
If the values are guaranteed non-negative you can get short-circuiting behavior:
try:
    np.empty((2,), bool)[ar]
    is_binary = True
except IndexError:
    is_binary = False
This method (always) allocates a temp array of the same shape as the argument and seems to loop over the data slower than the first method.
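Both ideas can be wrapped as small functions. This is only a sketch: the function names are hypothetical, and it assumes an integer dtype (plus non-negative values for the indexing variant, since -1 and -2 are valid indices into a length-2 array).

```python
import numpy as np

def is_binary_bitwise(ar):
    """Integer or boolean arrays only; bitwise_or is undefined for floats."""
    ar = np.asarray(ar)
    # axis=None reduces over every axis, so the result is a single scalar.
    combined = np.bitwise_or.reduce(ar, axis=None)
    # Any bit other than the lowest one, or a sign bit, pushes this outside [0, 1].
    return bool(0 <= combined <= 1)

def is_binary_by_indexing(ar):
    """Assumes ar holds non-negative integers."""
    lookup = np.empty((2,), dtype=bool)  # only indices 0 and 1 exist
    try:
        lookup[np.asarray(ar)]           # out-of-range values raise IndexError
        return True
    except IndexError:
        return False

good = np.array([[0, 1], [1, 0]])
bad = np.array([[0, 1], [3, 0]])
bitwise_results = (is_binary_bitwise(good), is_binary_bitwise(bad))
indexing_results = (is_binary_by_indexing(good), is_binary_by_indexing(bad))
```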
Upvotes: 4
Reputation: 7517
It looks like you can achieve it with something like:
np.array_equal(a, a.astype(bool))
If your array is large, it should avoid copying too many arrays (as in some other answers). Thus, it should probably be slightly faster than other answers (not tested however).
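A minimal sketch of how this comparison behaves, including on floating-point input. The wrapper name is_binary_astype is hypothetical.

```python
import numpy as np

def is_binary_astype(a):
    """True if casting to bool and back leaves every value unchanged."""
    a = np.asarray(a)
    # astype(bool) maps 0 to False and every other value to True;
    # False == 0 and True == 1, so only 0 and 1 survive the comparison.
    return bool(np.array_equal(a, a.astype(bool)))

int_binary = is_binary_astype(np.array([[0, 1], [1, 1]]))
int_other = is_binary_astype(np.array([0, 2]))
float_binary = is_binary_astype(np.array([0.0, 1.0]))
float_other = is_binary_astype(np.array([0.5, 1.0]))
```

Because it relies on plain equality rather than bitwise operations, this check also accepts float arrays whose values are exactly 0.0 or 1.0.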
Upvotes: 13
Reputation: 68682
If you have access to Numba (or alternatively Cython), you can write something like the following. It will be significantly faster at catching non-binary arrays, since it short-circuits and stops at the first offending element instead of processing all of them:
import numpy as np
import numba as nb

@nb.njit
def check_binary(x):
    is_binary = True
    for v in np.nditer(x):
        if v.item() != 0 and v.item() != 1:
            is_binary = False
            break
    return is_binary
Running this in pure Python, without the aid of an accelerator like Numba or Cython, makes this approach prohibitively slow.
Timings:
a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s
%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 15.1 ms per loop
%timeit check_binary(a)
# 100 loops, best of 3: 11.6 ms per loop
a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well
%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 14.9 ms per loop
%timeit check_binary(a)
# 1000000 loops, best of 3: 543 ns per loop
Upvotes: 3
Reputation: 221564
A few approaches -
((a==0) | (a==1)).all()
~((a!=0) & (a!=1)).any()
np.count_nonzero((a!=0) & (a!=1))==0
a.size == np.count_nonzero((a==0) | (a==1))
Runtime test -
In [313]: a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s
In [314]: %timeit ((a==0) | (a==1)).all()
...: %timeit ~((a!=0) & (a!=1)).any()
...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
...:
10 loops, best of 3: 28.8 ms per loop
10 loops, best of 3: 29.3 ms per loop
10 loops, best of 3: 28.9 ms per loop
10 loops, best of 3: 28.8 ms per loop
In [315]: a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well
In [316]: %timeit ((a==0) | (a==1)).all()
...: %timeit ~((a!=0) & (a!=1)).any()
...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
...:
10 loops, best of 3: 28 ms per loop
10 loops, best of 3: 27.5 ms per loop
10 loops, best of 3: 29.1 ms per loop
10 loops, best of 3: 28.9 ms per loop
Their runtimes seem to be comparable.
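The four expressions are logically equivalent (by De Morgan's laws), which a small sanity check can confirm. This is only a sketch: the seed, the smaller array size and the helper name all_variants are arbitrary choices for the example, and it does not measure timing.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily
samples = {
    "binary": rng.integers(0, 2, size=(300, 300)),      # only 0s and 1s
    "non_binary": rng.integers(0, 3, size=(300, 300)),  # contains 2 as well
}

def all_variants(a):
    """Evaluate the four approaches above and return their results as bools."""
    return [
        bool(((a == 0) | (a == 1)).all()),
        bool(~((a != 0) & (a != 1)).any()),
        bool(np.count_nonzero((a != 0) & (a != 1)) == 0),
        bool(a.size == np.count_nonzero((a == 0) | (a == 1))),
    ]

results = {name: all_variants(a) for name, a in samples.items()}
```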
Upvotes: 15