Victor Ferraz

Reputation: 39

What is the best pythonic solution?

Given X, an array of shape (n, m), and Y, a list of length n whose values are binary (0 or 1), what is the most Pythonic NumPy alternative to the following code?

p1 = np.zeros(X.shape[1])
p0 = np.zeros(X.shape[1])
for i in range(len(X[0])):
    sum_1 = np.where(Y==1,X[:,i],0).sum()
    sum_0 = np.where(Y==0,X[:,i],0).sum()
    p1[i] = sum_1
    p0[i] = sum_0

Upvotes: 3

Views: 121

Answers (3)

Eric Duminil

Reputation: 54283

If Y is a boolean array, you can use Y directly as an index, and NumPy selects the rows where Y is True. You can also negate Y with ~Y to get the remaining rows:

>>> X
array([[1, 2],
       [2, 3],
       [3, 4]])
>>> Y
array([False, False,  True])
>>> X[Y]
array([[3, 4]])
>>> X[~Y]
array([[1, 2],
       [2, 3]])
>>> X[Y].sum(axis=0)
array([3, 4])
>>> X[~Y].sum(axis=0)
array([3, 5])
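
If Y starts out as a plain list of 0s and 1s, as in the question, it needs to be converted to a boolean array before it can be used as a mask. A minimal sketch, assuming that conversion step (the sample values and the name Y_list are made up for illustration):

```python
import numpy as np

# Illustrative data, not from the question: 3 rows, 2 columns.
X = np.array([[1, 2],
              [2, 3],
              [3, 4]])
Y_list = [0, 0, 1]  # hypothetical list of 0/1 labels

# Convert to a boolean mask so it selects rows.
# (An integer array would be interpreted as row indices instead.)
Y = np.asarray(Y_list, dtype=bool)

p1 = X[Y].sum(axis=0)   # column sums over rows where Y is True
p0 = X[~Y].sum(axis=0)  # column sums over the remaining rows
```

Without dtype=bool, X[Y] would pick rows 0, 0 and 1 by position, which is a different operation from masking.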

Upvotes: 0

nneonneo

Reputation: 179592

Here's a faster and simpler version:

p1 = X.T @ Y # or np.dot(X.T, Y) if on Python < 3.5
p0 = X.T @ (1 - Y)

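This works because your Y array contains only zeros and ones: X.T @ Y adds up the rows of X where Y is 1, and X.T @ (1 - Y) adds up the rows where Y is 0, each computed as a fast matrix-vector product.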
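
As a small sanity check with made-up values, the dot products can be compared against explicit boolean masking. This sketch assumes Y is an integer NumPy array; a plain Python list would raise a TypeError on 1 - Y:

```python
import numpy as np

# Illustrative data, not from the question.
X = np.array([[1.0, 2.0],
              [2.0, 3.0],
              [3.0, 4.0]])
Y = np.array([0, 0, 1])

# Entry j of X.T @ Y is sum_i X[i, j] * Y[i]; rows with Y[i] == 0 add nothing.
p1 = X.T @ Y
p0 = X.T @ (1 - Y)

# Reference results using explicit row masks.
mask_p1 = X[Y == 1].sum(axis=0)
mask_p0 = X[Y == 0].sum(axis=0)

assert np.allclose(p1, mask_p1)
assert np.allclose(p0, mask_p0)
```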


Timing results with the following framework, saved as test.py:

import numpy as np

n = 2000
m = 1000
X = np.random.random((n, m))
Y = (np.random.random((n,)) > 0.5).astype(int)

def v0():
    p1 = np.zeros(X.shape[1])
    p0 = np.zeros(X.shape[1])
    for i in range(len(X[0])):
        sum_1 = np.where(Y==1,X[:,i],0).sum()
        sum_0 = np.where(Y==0,X[:,i],0).sum()
        p1[i] = sum_1
        p0[i] = sum_0
    return p0, p1

def v1():
    p1 = np.sum(X[np.where(Y==1)], axis=0)
    p0 = np.sum(X[np.where(Y==0)], axis=0)
    return p0, p1

def v2():
    p1 = X.T @ Y # or np.dot(X.T, Y) if on Python < 3.5
    p0 = X.T @ (1 - Y)
    return p0, p1

p0_0, p1_0 = v0()
p0_1, p1_1 = v1()
p0_2, p1_2 = v2()
assert np.allclose(p0_0, p0_1)
assert np.allclose(p0_0, p0_2)
assert np.allclose(p1_0, p1_1)
assert np.allclose(p1_0, p1_2)

$ python3 -m timeit -s 'import test' 'test.v0()'
10 loops, best of 5: 33.5 msec per loop
$ python3 -m timeit -s 'import test' 'test.v1()'
100 loops, best of 5: 3.81 msec per loop
$ python3 -m timeit -s 'import test' 'test.v2()'
500 loops, best of 5: 794 usec per loop

For this set of sizes, the dot-product version (v2) is over 40x faster than your original loop (v0), and roughly 5x faster than the np.where indexing version (v1).

Upvotes: 5

EuxhenH

Reputation: 161

You're summing over the first axis of X, restricted to the rows that satisfy a condition on Y:

p1 = np.sum(X[np.where(Y==1)], axis=0)
p0 = np.sum(X[np.where(Y==0)], axis=0)
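
With a single argument, np.where returns the indices of the matching rows, so indexing with the boolean mask Y == 1 directly should select the same rows. Either form needs Y to be a NumPy array, because comparing a plain list with 1 gives a single False instead of an elementwise result. A minimal sketch with made-up data:

```python
import numpy as np

# Illustrative data, not from the question.
X = np.array([[1, 2],
              [2, 3],
              [3, 4]])
Y = np.asarray([0, 0, 1])  # convert the list of 0/1 labels to an array

# Boolean-mask form: no np.where call needed.
p1 = np.sum(X[Y == 1], axis=0)
p0 = np.sum(X[Y == 0], axis=0)

# The np.where form selects the same rows.
assert np.array_equal(p1, np.sum(X[np.where(Y == 1)], axis=0))
assert np.array_equal(p0, np.sum(X[np.where(Y == 0)], axis=0))
```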

Upvotes: 1
