Intersection of sorted numpy arrays

Question

I have a list of sorted numpy arrays. What is the most efficient way to compute the sorted intersection of these arrays?

In my application, I expect the number of arrays to be less than 10^4, I expect the individual arrays to be of length less than 10^7, and I expect the length of the intersection to be close to p*N, where N is the length of the largest array and where 0.99 < p <= 1.0. The arrays are loaded from disk and can be loaded in batches if they won't all fit in memory at once.

A quick and dirty approach is to repeatedly invoke numpy.intersect1d(). That seems inefficient, though, as intersect1d() does not take advantage of the fact that the arrays are already sorted.
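
For reference, the quick and dirty approach might look like the sketch below. The helper name intersect_all_baseline and the tiny input arrays are made up for illustration. Only numpy.intersect1d() and functools.reduce() are real library calls.

```python
from functools import reduce

import numpy as np


def intersect_all_baseline(arrays):
    """Fold np.intersect1d over a list of arrays (hypothetical helper name)."""
    # intersect1d returns the sorted, unique values common to its two inputs,
    # so the running result stays sorted from one call to the next.
    return reduce(np.intersect1d, arrays)


# Tiny illustrative inputs, not the real data.
arrays = [
    np.array([1, 2, 4, 7, 9]),
    np.array([2, 4, 5, 7, 9]),
    np.array([0, 2, 4, 9]),
]
result = intersect_all_baseline(arrays)
print(result.tolist())  # [2, 4, 9]
```

Each call works on two arrays at a time and does not rely on the inputs being sorted, which is where the extra cost comes from.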


numpy | 2017-10-04 21:10 | 1 answer

Answers to Intersection of sorted numpy arrays ( 1 )

  1. 2017-10-04 22:10

    Since intersect1d sorts its inputs on every call, it is indeed inefficient here.

    Instead, sweep the current intersection and each new sample together, with one cursor per array, to build the new intersection. This takes linear time and preserves the sorted order.

    Such tasks often have to be hand-tuned with low-level routines.

    Here is one way to do it with numba:

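    from numba import njit
    import numpy as np

    @njit
    def drop_missing(intersect, sample):
        # Merge-style sweep over two sorted arrays: i walks intersect,
        # j walks sample, k counts the common elements kept so far.
        i = j = k = 0
        new_intersect = np.empty_like(intersect)
        while i < intersect.size and j < sample.size:
            if intersect[i] == sample[j]:  # the 99% case
                new_intersect[k] = intersect[i]
                k += 1
                i += 1
                j += 1
            elif intersect[i] < sample[j]:
                i += 1  # intersect[i] is missing from sample
            else:
                j += 1  # sample[j] is not in the intersection
        return new_intersect[:k]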
    

    Now the test samples:

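    n = 10**7
    ref = np.random.randint(0, n, n)
    ref.sort()

    def perturbation(sample, k):
        # Split sample at k-1 random positions and drop the last element
        # of each piece, removing up to k elements in total.
        rands = np.random.randint(0, sample.size, k - 1)
        rands.sort()
        pieces = np.split(sample, rands)
        return np.concatenate([a[:-1] for a in pieces])

    samples = [perturbation(ref, 100) for _ in range(10)]  # similar samples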
    

    And a run over 10 samples:

    def find_intersect(samples):
        intersect = samples[0]
        for sample in samples[1:]:
            intersect = drop_missing(intersect, sample)
        return intersect
    
    In [18]: %time u=find_intersect(samples)
    Wall time: 307 ms
    
    In [19]: len(u)
    Out[19]: 9999009     
    

    That is about 30 ms per array, so the 10^4 arrays in the question should take about 5 minutes, excluding loading time.
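
    Since the arrays come from disk in batches, the same fold can also consume them one at a time. The sketch below is only an illustration under a few assumptions: fake_loader stands in for whatever actually reads the arrays, and drop_missing_np is a NumPy-only substitute for drop_missing that uses np.searchsorted (binary search, so O(n log m) per step rather than linear). Unlike drop_missing, it keeps every duplicate in intersect whose value occurs in sample.

```python
import numpy as np


def drop_missing_np(intersect, sample):
    """Keep the elements of sorted `intersect` that also occur in sorted `sample`."""
    if sample.size == 0:
        return intersect[:0]
    # Binary-search where each element of intersect would sit in sample.
    pos = np.searchsorted(sample, intersect)
    # Elements larger than everything in sample get pos == sample.size; clamp them.
    pos[pos == sample.size] = sample.size - 1
    return intersect[sample[pos] == intersect]


def intersect_stream(arrays):
    """Fold any iterable of sorted arrays into their running intersection."""
    it = iter(arrays)
    intersect = next(it)
    for sample in it:
        intersect = drop_missing_np(intersect, sample)
    return intersect


def fake_loader():
    # Hypothetical stand-in for loading arrays from disk one at a time.
    yield np.array([1, 3, 5, 7, 9, 11])
    yield np.array([1, 2, 3, 7, 9, 10, 11])
    yield np.array([3, 7, 8, 11, 12])


common = intersect_stream(fake_loader())
print(common.tolist())  # [3, 7, 11]
```

    Only the running intersection and the current array need to be in memory at once. Swapping drop_missing back in for drop_missing_np should restore the linear-time sweep.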
