solphy101

k-NN search on a very large dataset with a TorchScript-compatible library

I am trying to run a k-NN search on a very large dataset (1e5 points). Libraries like PyKeOps work fine in terms of memory and time, but unfortunately PyKeOps is not TorchScript compatible. Is there another way to do this search that is memory efficient? The main requirement is that it must be TorchScript compatible.

Karl

You can use torch.cdist to compute pairwise distances efficiently, and it is TorchScript compatible. This still requires materializing an (m, n) distance matrix for m queries and n references. If that uses too much memory, the search can be broken into batches.
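
To gauge when batching is needed: the distance matrix holds m·n values, so for the sizes used in the example below (m = 10,000 queries, n = 100,000 references, float32) it needs 10,000 × 100,000 × 4 bytes = 4 GB. A minimal sketch, assuming a memory budget you pick yourself, of sizing a query batch so each (rows, n) block of distances fits; `distance_matrix_bytes` and `max_query_batch` are hypothetical helpers, not part of PyTorch:

```python
def distance_matrix_bytes(m: int, n: int, bytes_per_elem: int = 4) -> int:
    """Memory needed to hold an (m, n) distance matrix (float32 by default)."""
    return m * n * bytes_per_elem


def max_query_batch(n: int, budget_bytes: int, bytes_per_elem: int = 4) -> int:
    """Largest number of query rows whose (rows, n) distance block fits in budget_bytes.

    Hypothetical helper: the budget is whatever you decide to spend on distances.
    Always returns at least 1 so a batched loop can make progress.
    """
    rows = budget_bytes // (n * bytes_per_elem)
    return max(1, rows)


print(distance_matrix_bytes(10_000, 100_000))  # 4000000000 bytes, i.e. ~4 GB
print(max_query_batch(100_000, 512 * 2**20))   # 1342 query rows fit in a 512 MiB budget
```

If the references are chunked as well, pass the reference batch size instead of n as the column count.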

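import torch
import torch.nn as nn
from typing import Tuple


class KNNSearch(nn.Module):
    def __init__(self, p=2):
        super().__init__()
        self.p = float(p)  # exponent of the Minkowski distance passed to torch.cdist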
        
    @torch.jit.export
    def compute_distances(self,
                          queries: torch.Tensor,   # [m, d] tensor
                          references: torch.Tensor # [n, d] tensor
                         ) -> torch.Tensor:        # [m, n] distances
        
        distances = torch.cdist(queries, references, p=self.p)
        return distances
    
    @torch.jit.export
    def knn(self,
            queries: torch.Tensor,                  # [m, d] tensor
            references: torch.Tensor,               # [n, d] tensor
            k: int                                  # k nearest neighbors
           ) -> Tuple[torch.Tensor, torch.Tensor]:  # [m, k] distances and [m, k] indices
        
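        k = min(k, references.shape[0])  # topk fails if k exceeds the number of references
        distances = self.compute_distances(queries, references)  # full [m, n] matrix

        k_distances, k_indices = torch.topk(distances, k, dim=1, largest=False)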
        return k_distances, k_indices
    
    @torch.jit.export
    def knn_batched(self,
                    queries: torch.Tensor,                 # [m, d] tensor
                    references: torch.Tensor,              # [n, d] tensor
                    k: int,                                # k nearest neighbors
                    query_batch_size: int,                 # query batch size
                    reference_batch_size: int              # reference batch size
                   ) -> Tuple[torch.Tensor, torch.Tensor]: # [m, k] distances and [m, k] indices

        k = min(k, references.shape[0])
        M = queries.shape[0]
        N = references.shape[0]
        dtype = queries.dtype
        device = queries.device
        
        indices = torch.zeros((M, k), dtype=torch.long, device=device)
        distances = torch.zeros((M, k), dtype=dtype, device=device)

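        for q_start in range(0, M, query_batch_size):
            q_end = min(q_start + query_batch_size, M)
            q_batch = queries[q_start:q_end]
            B = q_end - q_start

            # Running top-k for this query batch: only [B, k] values are kept
            # between reference chunks instead of a full [B, N] distance buffer.
            best_d = torch.full([B, k], float("inf"), dtype=dtype, device=device)
            best_i = torch.full([B, k], -1, dtype=torch.long, device=device)

            for r_start in range(0, N, reference_batch_size):
                r_end = min(r_start + reference_batch_size, N)
                r_batch = references[r_start:r_end]

                dist_chunk = self.compute_distances(q_batch, r_batch)  # [B, r_end - r_start]
                kk = min(k, r_end - r_start)
                chunk_d, chunk_i = torch.topk(dist_chunk, kk, dim=1, largest=False)
                chunk_i = chunk_i + r_start  # chunk-local -> global reference index

                # Merge this chunk's candidates with the current best k.
                cand_d = torch.cat([best_d, chunk_d], dim=1)
                cand_i = torch.cat([best_i, chunk_i], dim=1)
                best_d, pos = torch.topk(cand_d, k, dim=1, largest=False)
                best_i = torch.gather(cand_i, 1, pos)

            distances[q_start:q_end] = best_d
            indices[q_start:q_end] = best_i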
        
        return distances, indices

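device = "cuda" if torch.cuda.is_available() else "cpu"
query_points = torch.randn(10000, 128, device=device)
reference_points = torch.randn(100000, 128, device=device)

knn = KNNSearch()
scripted_knn = torch.jit.script(knn)

# Unbatched: materializes the full [10000, 100000] float32 matrix (~4 GB).
dists, idx = scripted_knn.knn(query_points, reference_points, 3)
# Batched: processes 1000 queries and 10000 references per step.
dists_b, idx_b = scripted_knn.knn_batched(query_points, reference_points, 3, 1000, 10000)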
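
Chunking the references only saves memory if each chunk is reduced to its k best candidates before moving on; otherwise a [query_batch, n] buffer is still needed. This works because the overall k nearest neighbours are always contained in the current best k plus the chunk's own top k. A minimal pure-Python sketch of that merge for a single query row, where `heapq.nsmallest` stands in for `torch.topk` and the function names are hypothetical:

```python
import heapq
import random
from typing import List, Tuple


def topk_smallest(values: List[float], k: int, offset: int = 0) -> List[Tuple[float, int]]:
    """(value, global_index) pairs of the k smallest values, sorted ascending."""
    return heapq.nsmallest(k, ((v, offset + i) for i, v in enumerate(values)))


def chunked_topk(values: List[float], k: int, chunk_size: int) -> List[Tuple[float, int]]:
    """Top-k smallest over `values`, keeping only k candidates between chunks."""
    best: List[Tuple[float, int]] = []
    for start in range(0, len(values), chunk_size):
        chunk = values[start:start + chunk_size]
        # Candidates from this chunk, with chunk-local indices shifted to global ones.
        cand = topk_smallest(chunk, k, offset=start)
        # The global top-k is always contained in (current best) + (chunk's top-k).
        best = heapq.nsmallest(k, best + cand)
    return best


random.seed(0)
row = [random.random() for _ in range(1000)]  # distances from one query to 1000 references
print(chunked_topk(row, 3, 128) == topk_smallest(row, 3))  # True
```

Peak memory per query is then one chunk plus k candidates, regardless of n.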

That said, torch.cdist already dispatches to optimized low-level kernels, so scripting is unlikely to make it any faster. If you need to push k-NN performance further, consider an approximate nearest-neighbour index such as HNSW, available in libraries like HNSWlib or FAISS. These are not TorchScript compatible, but by trading exactness for speed they will typically give much better performance on large datasets.
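
Because HNSW search is approximate, it is worth checking its results against the exact brute-force neighbours on a sample of queries. A minimal sketch of recall@k, assuming both methods return, per query, a list of k reference indices (for example the index tensor returned by `knn` converted with `.tolist()`, and the labels from the HNSW index); `recall_at_k` is a hypothetical helper:

```python
from typing import Sequence


def recall_at_k(exact: Sequence[Sequence[int]], approx: Sequence[Sequence[int]]) -> float:
    """Fraction of the true k nearest neighbours that the approximate search also found."""
    hits = 0
    total = 0
    for true_ids, found_ids in zip(exact, approx):
        hits += len(set(true_ids) & set(found_ids))
        total += len(true_ids)
    return hits / total if total else 1.0


exact = [[0, 5, 9], [2, 3, 7]]   # exact neighbours for two queries
approx = [[0, 9, 4], [2, 3, 7]]  # approximate neighbours for the same queries
print(recall_at_k(exact, approx))  # 5 of 6 true neighbours found -> 0.8333...
```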
