Zarathustra

Efficient pythonic way for this operation

I am looking for an efficient way to perform the following operation; here is a minimal working example:

import numpy as np
from scipy.signal import fftconvolve

n = 7
m = 100
N = 3000

a = np.random.rand( n,m,N ) + np.random.rand( n,m,N )*1j
b = np.random.rand( n,m,N ) + np.random.rand( n,m,N )*1j

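# we want a product over the n-dimension (all pairs i, k), with fftconvolve along the
# last (N) axis, elementwise in the m index, and for each i the results summed over k

old_shape = a.shape
new_shape = ( n*m, N )

a = a.reshape( new_shape )

results = []
for i in range( n ):
    b_tiled = np.tile( b[ i, :, : ], ( n, 1, 1 )).reshape( new_shape )
    result  = fftconvolve( b_tiled, a, mode="same", axes=-1 ).reshape( old_shape )
    results.append( result.sum( axis=0 ) )  # keep the summed result for this i

results = np.array( results )  # shape (n, m, N)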

The operation computes the FFT-based convolution of the two arrays, combined in a product-like way over the first index (so I am avoiding a double loop over the range(n) indices, using just one).
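To make the intended semantics explicit, here is a sketch of the double loop that the single loop is meant to replace, assuming the goal is out[i] = sum over k of the convolution of b[i] with a[k] along the last axis. The names double_loop and single_loop and the small sizes are illustrative choices, not part of the original code; the check simply confirms that both loops give the same (n, m, N) array.

```python
import numpy as np
from scipy.signal import fftconvolve


def double_loop(a, b):
    """Explicit double loop: out[i] = sum_k conv(b[i], a[k]) along the last axis."""
    n, m, N = a.shape
    out = np.zeros((n, m, N), dtype=np.result_type(a, b))
    for i in range(n):
        for k in range(n):
            # Row-wise 1D convolution over the last axis; the m rows are independent
            out[i] += fftconvolve(b[i], a[k], mode="same", axes=-1)
    return out


def single_loop(a, b):
    """The single loop from the question, keeping each per-i result."""
    n, m, N = a.shape
    a_flat = a.reshape(n * m, N)
    out = []
    for i in range(n):
        b_tiled = np.tile(b[i], (n, 1, 1)).reshape(n * m, N)
        result = fftconvolve(b_tiled, a_flat, mode="same", axes=-1).reshape(n, m, N)
        out.append(result.sum(axis=0))
    return np.array(out)


rng = np.random.default_rng(0)
n, m, N = 3, 4, 16  # small sizes so the check runs quickly
a = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
b = rng.random((n, m, N)) + 1j * rng.random((n, m, N))

print(np.allclose(double_loop(a, b), single_loop(a, b)))
```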


Answers (1)

Bob

Reference implementation

I will start by wrapping the operation in a function so that I can easily compare implementations.

def ref_impl(a, b):
    n, m, N = a.shape
    a = a.reshape( m*n, N )
    ans = []
    for i in range( n ):
        b_tiled = np.tile( b[ i, :, : ], ( n, 1, 1 )).reshape( n*m, N )
        result  = fftconvolve( b_tiled, a, mode="same", axes=-1 ).reshape( n, m, N )
        ans.append( result.sum( axis=0 ) )
    return np.array( ans )
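Since the point of wrapping the operation is to compare implementations, a small harness along these lines could be used. compare and double_loop are hypothetical helpers introduced here for illustration: compare runs two functions on the same random complex input, checks agreement with np.allclose, and reports wall-clock time from time.perf_counter, so the timings are only indicative.

```python
import time

import numpy as np
from scipy.signal import fftconvolve


def ref_impl(a, b):
    n, m, N = a.shape
    a = a.reshape(m * n, N)
    ans = []
    for i in range(n):
        b_tiled = np.tile(b[i, :, :], (n, 1, 1)).reshape(n * m, N)
        result = fftconvolve(b_tiled, a, mode="same", axes=-1).reshape(n, m, N)
        ans.append(result.sum(axis=0))
    return np.array(ans)


def double_loop(a, b):
    # Hypothetical baseline: out[i] = sum_k conv(b[i], a[k]) along the last axis
    n = a.shape[0]
    return np.array([
        sum(fftconvolve(b[i], a[k], mode="same", axes=-1) for k in range(n))
        for i in range(n)
    ])


def compare(impl, reference, shape, seed=0):
    """Run both functions on the same random input; return (agree, t_impl, t_ref)."""
    rng = np.random.default_rng(seed)
    a = rng.random(shape) + 1j * rng.random(shape)
    b = rng.random(shape) + 1j * rng.random(shape)

    t0 = time.perf_counter()
    expected = reference(a, b)
    t_ref = time.perf_counter() - t0

    t0 = time.perf_counter()
    got = impl(a, b)
    t_impl = time.perf_counter() - t0

    return np.allclose(got, expected), t_impl, t_ref


agree, t_impl, t_ref = compare(double_loop, ref_impl, (3, 4, 16))
print(agree, f"{t_impl:.4f}s", f"{t_ref:.4f}s")
```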

Simplify sum of tiled array

np.tile replicates one slice n times along the first axis, and sum reduces all the elements along one axis.

Notice that b_tiled is constant along the first axis. There is some reshaping, but everything stays aligned, so (with a in its original (n, m, N) shape) the first result satisfies result[k] = fftconvolve(b[i, :, :], a[k, :, :], mode="same", axes=-1), and the summed result can be calculated as

result = sum(fftconvolve(b[i, :, :], a[k, :, :], mode="same", axes=-1) for k in range(n))
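The alignment claim can be checked directly on small random arrays: after the reshape, row k*m + j of b_tiled is b[i, j] and row k*m + j of the reshaped a is a[k, j], so each (m, N) block of the convolved output is a single fftconvolve of b[i] with a[k]. The sizes and the fixed outer index i below are arbitrary choices made only for this check.

```python
import numpy as np
from scipy.signal import fftconvolve

rng = np.random.default_rng(1)
n, m, N = 3, 4, 16
a = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
b = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
i = 1  # any fixed outer index

a_flat = a.reshape(n * m, N)
b_tiled = np.tile(b[i], (n, 1, 1)).reshape(n * m, N)

# Row k*m + j of b_tiled is b[i, j]; row k*m + j of a_flat is a[k, j]
rows_aligned = all(
    np.array_equal(b_tiled[k * m + j], b[i, j])
    and np.array_equal(a_flat[k * m + j], a[k, j])
    for k in range(n)
    for j in range(m)
)

result = fftconvolve(b_tiled, a_flat, mode="same", axes=-1).reshape(n, m, N)

# Each block result[k] is the convolution of b[i] with a[k] alone
blocks_match = all(
    np.allclose(result[k], fftconvolve(b[i], a[k], mode="same", axes=-1))
    for k in range(n)
)

# So summing over the first axis gives the sum of the per-k convolutions
summed = sum(fftconvolve(b[i], a[k], mode="same", axes=-1) for k in range(n))
print(rows_aligned, blocks_match, np.allclose(result.sum(axis=0), summed))
```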

Since convolution is linear in each argument, this can be written as

result = fftconvolve(b[i, :, :], sum(a[k, :, :] for k in range(n)), mode="same", axes=-1)
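Linearity can also be confirmed numerically: summing the n convolutions gives the same array, up to floating-point rounding, as convolving once with the pre-summed signal, and sum(a[k] for k in range(n)) is the same as a.sum(axis=0). Here b_i is a random (m, N) array standing in for b[i, :, :]; the sizes are arbitrary.

```python
import numpy as np
from scipy.signal import fftconvolve

rng = np.random.default_rng(2)
n, m, N = 5, 3, 32
a = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
b_i = rng.random((m, N)) + 1j * rng.random((m, N))  # stands in for b[i, :, :]

# n convolutions, then a sum
many = sum(fftconvolve(b_i, a[k], mode="same", axes=-1) for k in range(n))

# one convolution of the pre-summed signal
a_sum = a.sum(axis=0)
once = fftconvolve(b_i, a_sum, mode="same", axes=-1)

print(np.allclose(many, once))  # True, up to rounding
```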

In this form the sum over k can be vectorised as a.sum(axis=0), leaving only the loop over i:

def impl1(a, b):
    n, m, N = a.shape
    ans = []
    for i in range( n ):
        result = fftconvolve( b[ i, :, : ], a.sum( axis=0 ), mode="same", axes=-1 )
        ans.append( result )
    return np.array( ans )
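A quick way to confirm that the loop over i matches the reference is to run both on small random inputs. The variant below, impl1_hoisted (a name introduced here), also computes a.sum(axis=0) once outside the loop, since it does not depend on i; that hoisting is an optional tweak, not part of impl1 as written above.

```python
import numpy as np
from scipy.signal import fftconvolve


def ref_impl(a, b):
    n, m, N = a.shape
    a = a.reshape(m * n, N)
    ans = []
    for i in range(n):
        b_tiled = np.tile(b[i, :, :], (n, 1, 1)).reshape(n * m, N)
        result = fftconvolve(b_tiled, a, mode="same", axes=-1).reshape(n, m, N)
        ans.append(result.sum(axis=0))
    return np.array(ans)


def impl1_hoisted(a, b):
    # Same as impl1, but the sum over the first axis is computed once, outside the loop
    n, m, N = a.shape
    a_sum = a.sum(axis=0)
    return np.array([fftconvolve(b[i], a_sum, mode="same", axes=-1) for i in range(n)])


rng = np.random.default_rng(3)
shape = (4, 5, 64)
a = rng.random(shape) + 1j * rng.random(shape)
b = rng.random(shape) + 1j * rng.random(shape)

print(np.allclose(impl1_hoisted(a, b), ref_impl(a, b)))
```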

Simplify stack of slices

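Indexing b[i] takes one slice along the first axis, and np.array(ans) stacks the per-slice results back along that same axis. Since fftconvolve with axes=-1 treats every other axis independently, this loop-then-stack pattern can be replaced by a single vectorized call: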

def impl2(a, b):
    n, m, N = a.shape
    # tile the (m, N) sum n times so both inputs have shape (n, m, N)
    return fftconvolve( b, np.tile( a.sum( axis=0 ), ( n, 1, 1 )), mode="same", axes=-1 )
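The single-call version can be checked against the loop in the same way. Passing the whole (n, m, N) array b and the tiled sum to one fftconvolve call with axes=-1 convolves each of the n*m rows independently, which is what the loop over i did slice by slice. The sizes and seed below are arbitrary.

```python
import numpy as np
from scipy.signal import fftconvolve


def impl1(a, b):
    n, m, N = a.shape
    return np.array([fftconvolve(b[i], a.sum(axis=0), mode="same", axes=-1) for i in range(n)])


def impl2(a, b):
    n, m, N = a.shape
    # Replicate the (m, N) sum n times so both inputs have shape (n, m, N)
    a_sum_tiled = np.tile(a.sum(axis=0), (n, 1, 1))
    return fftconvolve(b, a_sum_tiled, mode="same", axes=-1)


rng = np.random.default_rng(4)
shape = (4, 5, 64)
a = rng.random(shape) + 1j * rng.random(shape)
b = rng.random(shape) + 1j * rng.random(shape)

print(np.allclose(impl2(a, b), impl1(a, b)))
```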

Complexity

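Both proposed implementations perform n*m length-N convolutions instead of the n*n*m done by the reference, which convolves an (n*m, N) array n times. impl1 also avoids building the tiled array, so it needs less scratch space, and impl2 additionally removes the Python loop and makes a single fftconvolve call instead of n. In summary, for large inputs the expected speedup is roughly a factor of n.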
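The factor of n can be seen by counting the length-N row convolutions each version performs, and a rough timing comparison can be made with time.perf_counter. rows_convolved and best_time are hypothetical helpers for this sketch, and the sizes are smaller than the question's (n=7, m=100, N=3000) so it runs quickly; measured times depend on the machine and the FFT backend, so no particular numbers are claimed.

```python
import time

import numpy as np
from scipy.signal import fftconvolve


def rows_convolved(n, m):
    """Number of independent length-N convolutions: (reference, proposed)."""
    ref = n * (n * m)  # n calls, each on an (n*m, N) array
    impl = n * m       # impl1: n calls on (m, N); impl2: one call on (n, m, N)
    return ref, impl


def ref_impl(a, b):
    n, m, N = a.shape
    a = a.reshape(m * n, N)
    ans = []
    for i in range(n):
        b_tiled = np.tile(b[i, :, :], (n, 1, 1)).reshape(n * m, N)
        result = fftconvolve(b_tiled, a, mode="same", axes=-1).reshape(n, m, N)
        ans.append(result.sum(axis=0))
    return np.array(ans)


def impl2(a, b):
    n, m, N = a.shape
    return fftconvolve(b, np.tile(a.sum(axis=0), (n, 1, 1)), mode="same", axes=-1)


def best_time(f, *args, repeat=3):
    """Best wall-clock time over a few runs."""
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        f(*args)
        best = min(best, time.perf_counter() - t0)
    return best


n, m, N = 7, 20, 512
ref_rows, impl_rows = rows_convolved(n, m)
print(ref_rows // impl_rows)  # 7, i.e. n

rng = np.random.default_rng(5)
a = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
b = rng.random((n, m, N)) + 1j * rng.random((n, m, N))
print(f"ref_impl: {best_time(ref_impl, a, b):.4f}s, impl2: {best_time(impl2, a, b):.4f}s")
```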
