wingsofpanda

Reputation: 5

Speed up Python 3 np.where operation

I wonder whether there is any way to speed up Python 3/NumPy's np.where operation. Here is a minimal working example:

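from time import time
import numpy as np

a = np.random.randint(0, 4, (768, 512, 512))  # platform-default integer dtype
b = a.astype(np.uint8)  # same values as 8-bit unsigned integers (astype copies)
c = b.copy()
print(a.shape)

# 1) boolean-mask assignment on the default-integer array
ts = time()
a[a > 0] = 1
print(f'normalize in {time() - ts}s')

# 2) boolean-mask assignment on the uint8 array
ts = time()
b[b > 0] = 1
print(f'normalize in {time() - ts}s')

# 3) np.where on the uint8 array (returns a new array)
ts = time()
c = np.where(c > 0, 1, c)
print(f'normalize in {time() - ts}s')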

The output is:

normalize in 0.9307191371917725s
normalize in 0.8891170024871826s
normalize in 0.7120938301086426s

As we can see, np.where gives the fastest result, but it is still pretty slow, and in my project I need to do this normalization about 50 times...

Is there any faster way to do this? All I need is to convert every non-zero element to 1. Thanks!

Upvotes: 0

Views: 149

Answers (2)

Mark Setchell

Reputation: 207405

You could try with numexpr:

import numpy as np
import numexpr as ne

# for timing reference
a = np.random.randint(0, 4, (768, 512, 512))
%timeit a[a > 0] = 1

782 ms ± 9.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

a = np.random.randint(0, 4, (768, 512, 512))
%timeit a = ne.evaluate("where(a > 0, 1, a)")

254 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Or with Numba:

import numba as nb

@nb.jit(nopython=True, fastmath=True, parallel=True)
def n(x):
    # prange runs the outer loop in parallel; x is modified in place
    for i in nb.prange(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                if x[i, j, k] > 0:
                    x[i, j, k] = 1
    return x

# Run this twice (or call n(a) once beforehand): the first call triggers JIT compilation
%timeit c = n(a)
# Also, try re-assigning back on top of original, i.e. a=n(a)

113 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Upvotes: 1

DYZ

Reputation: 57033

The best way to accomplish your operation is to convert the array to the boolean dtype and then back to an integer type, especially if 8-bit integers are used:

import numpy as np
# %timeit below is the IPython magic, so the timeit module is not needed

a = np.random.randint(0, 4, (768, 512, 512))
a_short = a.astype(np.uint8)

# Baseline: time for the copy alone
%timeit b=a.copy()
201 ms ± 561 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit b=a.copy(); b=b.astype(bool).astype(int)
373 ms ± 347 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit b=a.copy(); b=np.where(b>0,1,b)
985 ms ± 150 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit b=a.copy(); b[b>0]=1
1.09 s ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Baseline: time for the copy alone
%timeit b=a_short.copy()
26.7 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
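Part of the gap between the two baseline copies is plain memory traffic. A quick back-of-the-envelope check, assuming `np.random.randint` returned int64 here (its default integer dtype is platform-dependent):

```python
import numpy as np

shape = (768, 512, 512)
n_elements = int(np.prod(shape))  # 201,326,592 elements

# Bytes that one full pass over the array (e.g. a copy) has to touch, per dtype.
# Assumption: the default-integer array is int64 (8 bytes per element).
bytes_int64 = n_elements * np.dtype(np.int64).itemsize
bytes_uint8 = n_elements * np.dtype(np.uint8).itemsize  # 1 byte per element

print(f"int64: {bytes_int64 / 2**30:.2f} GiB")  # int64: 1.50 GiB
print(f"uint8: {bytes_uint8 / 2**20:.0f} MiB")  # uint8: 192 MiB
```

An 8x smaller array means 8x less data to move on every pass, which is roughly in line with the ~7.5x difference between the two copy timings above.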

The best of the best, with 8-bit integers the boolean round-trip is by far the fastest:

%timeit b=a_short.copy(); b=b.astype(bool).astype(np.uint8)
77.5 ms ± 47.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit b=a_short.copy(); b=np.where(b>0,1,b)
570 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit b=a_short.copy(); b[b>0]=1
844 ms ± 4.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
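The boolean round-trip works because casting to `bool` maps every non-zero value to `True` and zero to `False`, and casting back turns those into 1 and 0. A small sketch checking that it agrees with the other approaches on a toy array. The in-place `np.minimum(..., out=...)` variant at the end is an extra, not benchmarked above, and it is only equivalent for non-negative data such as the `randint(0, 4)` values used here:

```python
import numpy as np

rng = np.random.default_rng(0)
# Small stand-in for the (768, 512, 512) uint8 array used above.
b = rng.integers(0, 4, size=(4, 8, 8), dtype=np.uint8)

# 1) Boolean round-trip: non-zero -> True -> 1, zero -> False -> 0.
via_bool = b.astype(bool).astype(np.uint8)

# 2) The approaches from the question, for comparison.
via_where = np.where(b > 0, 1, b)
via_mask = b.copy()
via_mask[via_mask > 0] = 1

# 3) Extra in-place variant (assumption: data is non-negative, e.g. unsigned):
#    clamping to at most 1 then equals mapping non-zero values to 1,
#    and out= writes the result back into the existing buffer.
in_place = b.copy()
np.minimum(in_place, 1, out=in_place)

assert np.array_equal(via_bool, via_where)
assert np.array_equal(via_bool, via_mask)
assert np.array_equal(via_bool, in_place)
print(via_bool.dtype, np.unique(via_bool))
```

Whether the in-place variant is faster than the boolean round-trip on a full-size array would need to be timed the same way as above; it mainly avoids allocating a second array.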

Upvotes: 2
