dshin

Reputation: 2398

Intersection of sorted numpy arrays

I have a list of sorted numpy arrays. What is the most efficient way to compute the sorted intersection of these arrays?

In my application, I expect the number of arrays to be less than 10^4, I expect the individual arrays to be of length less than 10^7, and I expect the length of the intersection to be close to p*N, where N is the length of the largest array and where 0.99 < p <= 1.0. The arrays are loaded from disk and can be loaded in batches if they won't all fit in memory at once.

A quick and dirty approach is to repeatedly invoke numpy.intersect1d(). That seems inefficient though as intersect1d() does not take advantage of the fact that the arrays are sorted.
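
For concreteness, that baseline could be written as a small fold over the list (a minimal sketch, where arrays is assumed to be the list of loaded arrays):

from functools import reduce
import numpy as np

def naive_intersection(arrays):
    # Fold np.intersect1d over the list; each call re-sorts its inputs,
    # which is the redundant work described above.
    return reduce(np.intersect1d, arrays)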

Upvotes: 7

Views: 1816

Answers (2)

sauerburger

Reputation: 5138

A few months ago, I wrote a C++-based Python extension for this exact purpose. The package is called sortednp and is available via pip. The intersection of multiple sorted numpy arrays, for example a, b and c, can be calculated with

import sortednp as snp
i = snp.kway_intersect(a, b, c)

By default, this uses an exponential search to advance the array indices internally which is pretty fast in cases where the intersection is small. In your case, it might be faster if you add algorithm=snp.SIMPLE_SEARCH to the method call.
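
Following that description, the call with the explicit algorithm choice would look something like this (a short sketch based on the keyword named in the answer, not verified against every sortednp version):

import sortednp as snp

# With an intersection that keeps ~99% of the values, the exponential
# search advances by one element most of the time anyway, so the plain
# linear search may be faster.
i = snp.kway_intersect(a, b, c, algorithm=snp.SIMPLE_SEARCH)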

Upvotes: 2

B. M.

Reputation: 18628

Since intersect1d sorts its arguments on every call, it is indeed inefficient here.

Here you want to sweep through the current intersection and each new sample together, building the new intersection as you go; that can be done in linear time while preserving the sorted order.

Such a task often has to be tuned by hand with low-level routines.

Here is a way to do that with Numba:

from numba import njit
import numpy as np

@njit
def drop_missing(intersect, sample):
    # Merge-style sweep: keep only the values of intersect that also
    # appear in sample; both inputs are assumed sorted.
    i = j = k = 0
    new_intersect = np.empty_like(intersect)
    while i < intersect.size and j < sample.size:
        if intersect[i] == sample[j]:   # the 99% case: value is kept
            new_intersect[k] = intersect[i]
            k += 1
            i += 1
            j += 1
        elif intersect[i] < sample[j]:  # value missing from sample: drop it
            i += 1
        else:                           # sample value not in intersect: skip it
            j += 1
    return new_intersect[:k]

Now the samples:

n = 10**7
ref = np.random.randint(0, n, n)
ref.sort()

def perturbation(sample, k):
    # Split at k-1 random positions and drop the last element of each
    # chunk, removing about k values from the sample.
    rands = np.random.randint(0, n, k - 1)
    rands.sort()
    chunks = np.split(sample, rands)
    return np.concatenate([a[:-1] for a in chunks])

samples = [perturbation(ref, 100) for _ in range(10)]  # similar samples

And a run for 10 samples:

def find_intersect(samples):
    intersect = samples[0]
    for sample in samples[1:]:
        intersect = drop_missing(intersect, sample)
    return intersect

In [18]: %time u=find_intersect(samples)
Wall time: 307 ms

In [19]: len(u)
Out[19]: 9999009     

At roughly 300 ms per 10 samples, processing the full 10^4 arrays extrapolates to about 5 minutes of compute, not counting loading time.

Upvotes: 1
