usher96

Reputation: 37

Vectorizing for loops in python with numpy multidimensional arrays

I'm trying to improve the performance of the code below. Eventually it will be using much bigger arrays, but I thought I would start off with something simple that works, then look at where it is slow, optimise it, and then try it out at full size. Here is the original code:

#Minimum example with random variables
import numpy as np
import matplotlib.pyplot as plt

n=4
# Theoretical Travel Time to each station
ttable=np.array([1,2,3,4])
# Seismic traces,measured at each station
traces=np.random.random((n, 506))
dt=0.1
# Forward problem: add energy to each trace at the desired time from a given origin time
given_origin_time=1
for i in range(n):
    # Energy will arrive at the sample equivalent to origin time + travel time
    arrival_sample=int(round((given_origin_time+ttable[i])/dt))
    traces[i,arrival_sample]=2

# The aim is to find the origin time by trying each possible origin time and adding the energy up. 
# Where this "Stack" is highest is likely to be the origin time

# Find the maximum travel time
tmax=ttable.max()


# We pad the traces so that shifting by a travel time never runs off the end of the trace
traces=np.lib.pad(traces,((0,0),(round(tmax/dt),round(tmax/dt))),'constant',constant_values=0)

#Available origin times to search for relative to the beginning of the trace
origin_times=np.linspace(-tmax,len(traces),len(traces)+round(tmax/dt))

# Create an empty array to fill with our stack
S=np.empty((origin_times.shape[0]))

# Loop over all the potential origin times
for l,otime in enumerate(origin_times):
    # Create some variables which we will sum up over all stations
    sum_point=0
    sqrr_sum_point=0
    # Loop over each station
    for m in range(n):
        # Find the appropriate travel time
        ttime=ttable[m] 
        # Grab the point on the trace that corresponds to this travel time + the origin time we are searching for
        point=traces[m,int(round((tmax+otime+ttime)/dt))]
        # Sum up the points
        sum_point+=point
        # Sum of the square of the points
        sqrr_sum_point+=point**2
    # Create the stack by taking the square of the sums divided by the sum of the squares, normalised by the number of stations
    S[l]=sum_point#**2/(n*sqrr_sum_point)

# Plot the output the peak should be at given_origin_time
plt.plot(origin_times,S)
plt.show()

I think the problem is that I don't understand the broadcasting and indexing of multidimensional arrays. After this I will be extending the dimensions to search over x, y, z, which would be given by increasing the dimensions of ttable. I will probably try to implement either PyTables or np.memmap to help with the large arrays.
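For reference, here is a minimal np.memmap sketch for keeping a large trace array on disk; the file name, dtype, and sizes below are made up for illustration and are not taken from this question:

import numpy as np

# Hypothetical dimensions for the full-size problem
n_stations, n_samples = 1000, 500000

# Disk-backed array; 'traces.dat' is an assumed file name
traces_mm = np.memmap('traces.dat', dtype=np.float64, mode='w+',
                      shape=(n_stations, n_samples))

# It slices like a normal ndarray; only the touched blocks hit memory
traces_mm[0, :10] = np.random.random(10)
traces_mm.flush()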

Upvotes: 0

Views: 165

Answers (1)

perimosocordiae

Reputation: 17797

With some quick profiling, it appears that the line

point=traces[m,int(round((tmax+otime+ttime)/dt))]

is taking ~40% of the total program's runtime. Let's see if we can vectorize it a bit:

    ttime_inds = np.around((tmax + otime + ttable) / dt).astype(int)
    # Loop over each station
    for m in range(n):
        # Grab the point on the trace that corresponds to this travel time + the origin time we are searching for
        point=traces[m,ttime_inds[m]]

We noticed that the only thing changing in the loop (other than m) was ttime, so we pulled it out and vectorized that part using numpy functions.
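For illustration, a standalone check (not part of the original code) that the vectorized rounding produces the same indices as the per-element int(round(...)) calls, using the question's ttable, dt, and tmax and an arbitrarily chosen otime:

import numpy as np

ttable = np.array([1, 2, 3, 4])
tmax, dt = ttable.max(), 0.1
otime = 1.5  # arbitrary test value

# Per-element rounding, as in the original inner loop
loop_inds = [int(round((tmax + otime + t) / dt)) for t in ttable]
# Vectorized rounding over the whole ttable at once
vec_inds = np.around((tmax + otime + ttable) / dt).astype(int)

print(np.array_equal(loop_inds, vec_inds))  # expected: True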

That was the biggest hotspot, but we can go a bit further and remove the inner loop entirely:

# Loop over all the potential origin times
for l,otime in enumerate(origin_times):
    ttime_inds = np.around((tmax + otime + ttable) / dt).astype(int)
    points = traces[np.arange(n),ttime_inds]
    sum_point = points.sum()
    sqrr_sum_point = (points**2).sum()
    # Create the stack by taking the square of the sums divided by the sum of the squares, normalised by the number of stations
    S[l]=sum_point#**2/(n*sqrr_sum_point)
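The points line relies on NumPy integer (fancy) indexing: the two index arrays are paired element-wise, so row m contributes traces[m, ttime_inds[m]]. A minimal standalone illustration with a made-up 3x4 array:

import numpy as np

a = np.arange(12).reshape(3, 4)
# a = [[ 0  1  2  3]
#      [ 4  5  6  7]
#      [ 8  9 10 11]]

cols = np.array([2, 0, 3])    # one column index per row
print(a[np.arange(3), cols])  # [ 2  4 11]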

EDIT: If you want to remove the outer loop as well, we need to pull otime out:

ttime_inds = np.around((tmax + origin_times[:,None] + ttable) / dt).astype(int)

Then, we proceed as before, summing over the second axis:

points = traces[np.arange(n),ttime_inds]
sum_points = points.sum(axis=1)
sqrr_sum_points = (points**2).sum(axis=1)
S = sum_points # **2/(n*sqrr_sum_points)
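As a sanity check (not in the original answer), you could confirm that the fully vectorized stack matches the loop-based one, assuming traces, ttable, origin_times, tmax, dt, and n are defined as in the question:

import numpy as np

# Loop-based reference, same logic as before
S_loop = np.empty(origin_times.shape[0])
for l, otime in enumerate(origin_times):
    inds = np.around((tmax + otime + ttable) / dt).astype(int)
    S_loop[l] = traces[np.arange(n), inds].sum()

# Fully vectorized: ttime_inds broadcasts to shape (len(origin_times), n),
# so points has that shape too, and we sum across stations (axis=1)
ttime_inds = np.around((tmax + origin_times[:, None] + ttable) / dt).astype(int)
points = traces[np.arange(n), ttime_inds]
S_vec = points.sum(axis=1)

print(np.allclose(S_loop, S_vec))  # expected: True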

Upvotes: 2
