Reputation: 137
I am trying to run a k-NN search for a very large dataset (1e5 points). Libraries like PyKeOps work fine in terms of memory and time, but unfortunately they are not TorchScript compatible. Is there any other way I can do this search and stay memory efficient? The main requirement is that it must be TorchScript compatible.
Upvotes: -1
Views: 20
Reputation: 5473
You can use torch.cdist to efficiently compute pairwise distances in a way that is TorchScript compatible. This still requires materializing an (m, n) distance matrix; for the sizes in the question (m = 10,000 queries against n = 100,000 references in float32) that is roughly 4 GB. If that is too memory intensive, the computation can be broken into batches:
import torch
import torch.nn as nn
from typing import Tuple


class KNNSearch(nn.Module):
    def __init__(self, p=2):
        super().__init__()
        self.p = float(p)

    @torch.jit.export
    def compute_distances(self,
                          queries: torch.Tensor,    # [m, d] tensor
                          references: torch.Tensor  # [n, d] tensor
                          ) -> torch.Tensor:        # [m, n] distances
        distances = torch.cdist(queries, references, p=self.p)
        return distances

    @torch.jit.export
    def knn(self,
            queries: torch.Tensor,     # [m, d] tensor
            references: torch.Tensor,  # [n, d] tensor
            k: int                     # k nearest neighbors
            ) -> Tuple[torch.Tensor, torch.Tensor]:  # [m, k] distances and [m, k] indices
        distances = self.compute_distances(queries, references)
        k_distances, k_indices = torch.topk(distances, k, dim=1, largest=False)
        return k_distances, k_indices

    @torch.jit.export
    def knn_batched(self,
                    queries: torch.Tensor,     # [m, d] tensor
                    references: torch.Tensor,  # [n, d] tensor
                    k: int,                    # k nearest neighbors
                    query_batch_size: int,     # query batch size
                    reference_batch_size: int  # reference batch size
                    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [m, k] distances and [m, k] indices
        k = min(k, references.shape[0])
        M = queries.shape[0]
        N = references.shape[0]
        dtype = queries.dtype
        device = queries.device
        indices = torch.zeros((M, k), dtype=torch.long, device=device)
        distances = torch.zeros((M, k), dtype=dtype, device=device)
        for q_start in range(0, M, query_batch_size):
            q_end = min(q_start + query_batch_size, M)
            q_batch = queries[q_start:q_end]
            # distances from this query batch to all references
            batch_distances = torch.zeros((q_end - q_start, N), dtype=dtype, device=device)
            for r_start in range(0, N, reference_batch_size):
                r_end = min(r_start + reference_batch_size, N)
                r_batch = references[r_start:r_end]
                batch_distances[:, r_start:r_end] = self.compute_distances(q_batch, r_batch)
            # topk runs over the full reference axis, so indices are global
            batch_dist_k, batch_ind_k = torch.topk(batch_distances, k, dim=1, largest=False)
            indices[q_start:q_end] = batch_ind_k
            distances[q_start:q_end] = batch_dist_k
        return distances, indices


query_points = torch.randn(10000, 128).cuda()
reference_points = torch.randn(100000, 128).cuda()

knn = KNNSearch()
scripted_knn = torch.jit.script(knn)

res = scripted_knn.knn(query_points, reference_points, 3)
res = scripted_knn.knn_batched(query_points, reference_points, 3, 10000, 10000)
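As a quick sanity check, the batched and unbatched paths should agree (a minimal sketch; with random float data, exact distance ties that could reorder neighbors are vanishingly unlikely):

dists_full, idx_full = scripted_knn.knn(query_points, reference_points, 3)
dists_batched, idx_batched = scripted_knn.knn_batched(query_points, reference_points, 3, 1000, 10000)
assert torch.allclose(dists_full, dists_batched)
assert torch.equal(idx_full, idx_batched)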
That said, torch.cdist already dispatches to lower-level kernels, so you are unlikely to see any speedup from scripting itself. If you need to push k-NN performance further, consider approximate methods such as HNSW indexing with libraries like hnswlib, FAISS, or similar. These are not TorchScript compatible, but will give much better performance.
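For reference, a minimal hnswlib sketch might look like the following (the M, ef_construction, and ef values here are illustrative defaults, not tuned; note that hnswlib's 'l2' space returns squared L2 distances):

import hnswlib
import numpy as np

dim = 128
references = np.random.randn(100000, dim).astype(np.float32)
queries = np.random.randn(10000, dim).astype(np.float32)

# build the HNSW index; M and ef_construction trade build time for recall
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=references.shape[0], ef_construction=200, M=16)
index.add_items(references)

# ef controls the query-time recall/speed trade-off
index.set_ef(50)

# labels: [m, k] neighbor indices, distances: [m, k] squared L2 distances
labels, distances = index.knn_query(queries, k=3)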
Upvotes: 0