Exploring

Reputation: 3439

How to find closest embedding vectors?

I have 100K known embeddings, i.e.

[emb_1, emb_2, ..., emb_100000]

Each of these embeddings is derived from the GPT-3 sentence embedding model and has dimension 2048.

My task: given a new embedding (embedding_new), find the 10 closest embeddings among the 100K above.

The way I am approaching this problem is brute force.

Every time a query asks for the closest embeddings, I compare embedding_new against [emb_1, emb_2, ..., emb_100000] and compute a similarity score for each.

Then I quicksort the similarity scores to get the top 10 closest embeddings (see the sketch below).
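A minimal sketch of that approach, with random data standing in for my real embeddings and cosine similarity assumed as the score:

import numpy as np

# Random stand-ins for the real data: 100K embeddings of dimension 2048.
emb_mat = np.random.random((100000, 2048)).astype('float32')
embedding_new = np.random.random(2048).astype('float32')

# Cosine similarity of the query against every stored embedding.
scores = emb_mat @ embedding_new
scores /= np.linalg.norm(emb_mat, axis=1) * np.linalg.norm(embedding_new)

# Full sort of all 100K scores, keep the 10 most similar.
top_10_ind = np.argsort(scores)[-10:][::-1]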

Alternatively, I have also thought about using Faiss.

Is there a better way to achieve this?

Upvotes: 2

Views: 3266

Answers (2)

elkbrs

Reputation: 121

Using your own idea, just make sure the embeddings are stored as a matrix; NumPy makes this easy. The search then runs in linear time (in the number of embeddings) and should be fast.

import numpy as np

k = 10  # number of best embeddings to return
emb_mat = np.stack([emb_1, emb_2, ..., emb_100000])  # shape (100000, 2048)
scores = np.dot(emb_mat, embedding_new)  # one inner-product score per embedding
best_k_ind = np.argpartition(scores, -k)[-k:]  # indices of the k largest scores (unsorted)
top_k_emb = emb_mat[best_k_ind]
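Note that np.argpartition runs in linear time but leaves the top k unsorted; if you need them ranked by score, sort just those k indices afterwards, e.g. best_k_ind[np.argsort(scores[best_k_ind])[::-1]].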

The 10 best embeddings will then be in top_k_emb. For a more general solution inside a software project, you might consider Faiss by Facebook Research. An example of using Faiss:

import faiss
import numpy as np

d = 2048  # dimensionality of your embedding data
k = 10    # number of nearest neighbors to return

index = faiss.IndexFlatIP(d)  # exact inner-product search
emb_mat = np.stack([emb_1, emb_2, ..., emb_100000]).astype('float32')
index.add(emb_mat)  # Faiss expects a 2D float32 array, not a Python list

# search() also expects a 2D array: one row per query vector
D, I = index.search(embedding_new.reshape(1, d).astype('float32'), k)

You can use IndexFlatIP for inner-product similarity, or IndexFlatL2 for Euclidean (L2) distance. To get around memory issues with larger datasets (>1M vectors), refer to the great Faiss cheat sheet infographic, slide 7.
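One caveat: inner product equals cosine similarity only if the vectors are L2-normalized. A minimal sketch using Faiss's normalize_L2 helper, with random data standing in for the real embeddings:

import faiss
import numpy as np

d, k = 2048, 10
emb_mat = np.random.random((100000, d)).astype('float32')  # stand-in data
query = np.random.random((1, d)).astype('float32')

faiss.normalize_L2(emb_mat)  # in-place L2 normalization
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(d)
index.add(emb_mat)
D, I = index.search(query, k)  # D now holds cosine similarities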

Upvotes: 0

Exploring

Reputation: 3439

I found a solution using Vector Database Lite (VDBLITE).

VDBLITE here: https://pypi.org/project/vdblite/

import vdblite
import numpy as np
from time import time
from uuid import uuid4
from pprint import pprint as pp


if __name__ == '__main__':
    vdb = vdblite.Vdb()
    dimension = 12  # dimensions of each vector
    n = 200  # number of vectors
    np.random.seed(1)
    db_vectors = np.random.random((n, dimension)).astype('float32')
    print(db_vectors[0])
    for vector in db_vectors:
        info = {'vector': vector, 'time': time(), 'uuid': str(uuid4())}
        vdb.add(info)
    vdb.details()
    results = vdb.search(db_vectors[10])
    pp(results)

Looks like it uses FAISS behind the scenes.

Upvotes: 0
