I have a list of sorted numpy arrays. What is the most efficient way to compute the sorted intersection of these arrays?
In my application, I expect fewer than 10^4 arrays, each of length less than 10^7, and an intersection whose length is close to p*N, where N is the length of the largest array and 0.99 < p <= 1.0. The arrays are loaded from disk and can be loaded in batches if they won't all fit in memory at once.
A quick-and-dirty approach is to repeatedly invoke `numpy.intersect1d()`. That seems inefficient, though, as `intersect1d()` does not take advantage of the fact that the arrays are sorted.
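For reference, the quick-and-dirty baseline can be written as a fold with `functools.reduce`. This is a minimal sketch with a hypothetical helper name; note that `np.intersect1d` returns sorted *unique* values (duplicates within an array collapse) and sorts the concatenated inputs on every call, which is exactly the cost the question wants to avoid.

```python
from functools import reduce

import numpy as np


def intersect_all_naive(arrays):
    """Fold np.intersect1d over a list of arrays (hypothetical baseline helper).

    Each call sorts the concatenation of its two inputs again, so the
    existing order of the arrays is not exploited.
    """
    return reduce(np.intersect1d, arrays)


arrays = [np.array([1, 2, 3, 5, 8]), np.array([1, 3, 5, 7, 8]), np.array([0, 1, 5, 8, 9])]
print(intersect_all_naive(arrays))  # [1 5 8]
```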
A few months ago, I wrote a C++-based Python extension for this exact purpose. The package is called `sortednp` and is available via pip. The intersection of multiple sorted NumPy arrays, for example `a`, `b` and `c`, can be calculated with:

```python
import sortednp as snp
i = snp.kway_intersect(a, b, c)
```

By default, this uses an exponential search to advance the array indices internally, which is fast when the intersection is small. In your case, where the intersection is almost as large as the arrays themselves, it might be faster to add `algorithm=snp.SIMPLE_SEARCH` to the method call.
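To illustrate what "exponential search to advance the array indices" means, here is a pure Python/NumPy sketch for two arrays. `gallop_intersect` is a hypothetical helper written only to show the technique; it is not sortednp's C++ implementation and will be far slower than it. It walks the shorter array and, for each value, doubles a step through the longer array until it overshoots, then binary-searches the last window.

```python
import numpy as np


def gallop_intersect(a, b):
    """Intersect two sorted 1-D arrays using exponential (galloping) search.

    Hypothetical illustration only. Duplicates pair up one-to-one, so a
    value appearing twice in both arrays appears twice in the result.
    """
    if a.size > b.size:
        a, b = b, a  # walk the shorter array, gallop through the longer one
    n = b.size
    out = []
    j = 0  # index of the first element of b not yet consumed
    for x in a:
        if j >= n:
            break
        # Exponential phase: double the step while b[j + step] is still < x
        step = 1
        while j + step < n and b[j + step] < x:
            step *= 2
        # Binary phase: the first b[k] >= x lies in b[j : j + step + 1]
        hi = min(j + step + 1, n)
        k = j + int(np.searchsorted(b[j:hi], x))
        if k < n and b[k] == x:
            out.append(x)
            j = k + 1  # consume the matched element
        else:
            j = k
    return np.array(out, dtype=np.result_type(a, b))


print(gallop_intersect(np.array([1, 3, 5, 7, 9, 11]), np.array([0, 1, 2, 3, 4, 5, 6, 7])))  # [1 3 5 7]
```

Galloping pays off when matches are sparse, because long runs of non-matching elements are skipped in logarithmically many steps. When nearly every element matches, as in the question, the step rarely grows past 1, and a plain linear sweep does the same work with less overhead. For more than two arrays, the helper could be folded over a list with `functools.reduce(gallop_intersect, arrays)`.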
Since `intersect1d` sorts the arrays on every call, it is inefficient here.
Instead, you can sweep the current intersection and each sample together to build the new intersection. This can be done in linear time while maintaining order.
Such tasks often have to be hand-tuned with low-level routines.
Here is a way to do that with `numba`:
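```python
from numba import njit
import numpy as np

@njit
def drop_missing(intersect, sample):
    # Keep the elements of sorted `intersect` that also appear in sorted `sample`
    i = j = k = 0
    new_intersect = np.empty_like(intersect)
    # Two-pointer sweep over both sorted arrays
    while i < intersect.size and j < sample.size:
        if intersect[i] == sample[j]:  # the 99% case
            new_intersect[k] = intersect[i]
            k += 1
            i += 1
            j += 1
        elif intersect[i] < sample[j]:
            i += 1  # intersect[i] is missing from sample: drop it
        else:
            j += 1  # sample[j] is not in intersect: skip it
    return new_intersect[:k]
```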
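If numba is not available, the same pairwise step can be approximated in plain NumPy with `np.searchsorted`. This is a sketch with a hypothetical name, `drop_missing_np`; it costs O(n log m) rather than linear time, and it is a membership test, so unlike the two-pointer loop it keeps every duplicate of a value in `intersect` as long as `sample` contains that value at least once.

```python
import numpy as np


def drop_missing_np(intersect, sample):
    """Keep the elements of sorted `intersect` that occur in sorted `sample`.

    Vectorized membership test with np.searchsorted (hypothetical helper).
    """
    if sample.size == 0:
        return intersect[:0]
    # Position where each element of `intersect` would be inserted into `sample`
    idx = np.searchsorted(sample, intersect)
    # Positions past the end cannot match; clip them so indexing stays valid
    idx_clipped = np.minimum(idx, sample.size - 1)
    found = (idx < sample.size) & (sample[idx_clipped] == intersect)
    return intersect[found]


a = np.array([1, 2, 4, 4, 7, 9])
b = np.array([2, 4, 5, 9, 10])
print(drop_missing_np(a, b))  # [2 4 4 9]
```

On the same inputs, the two-pointer `drop_missing` returns `[2 4 9]`, because `b` contains only one `4`. The two versions agree whenever the arrays have no repeated values.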
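Now build some test samples:

```python
n = 10**7
ref = np.random.randint(0, n, n)
ref.sort()

def perturbation(sample, k):
    # Split `sample` at k-1 random positions and drop the last element of
    # each non-empty piece, removing at most k elements
    rands = np.random.randint(0, n, k - 1)
    rands.sort()
    l = np.split(sample, rands)
    return np.concatenate([a[:-1] for a in l])

samples = [perturbation(ref, 100) for _ in range(10)]  # similar samples
```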
And a run on 10 samples:

```python
def find_intersect(samples):
    intersect = samples[0]
    for sample in samples[1:]:
        intersect = drop_missing(intersect, sample)
    return intersect
```
```
In [18]: %time u=find_intersect(samples)
Wall time: 307 ms

In [19]: len(u)
Out[19]: 9999009
```
That is roughly 34 ms per pairwise merge (9 merges in 307 ms), so with up to 10^4 arrays the whole job could be done in on the order of 5 minutes, excluding loading time.