Arko

How to create an asynchronous embedding inference endpoint in Python?

I am trying to create an asynchronous embedding endpoint in a Python server that can handle many embedding requests concurrently.

If I create a normal synchronous endpoint, for example using FastAPI and FastEmbed, FastAPI will simply run the embedding function in a thread pool. I don't think that is good enough here: embedding is CPU-bound, blocking work and should ideally run in a process pool to achieve real concurrency. I did try that, but ran into serialisation issues.

Then I came across the Text Embeddings Inference (TEI) repository by Hugging Face, and in its source code I found the backends/python/server package.

Its main server.py file starts a gRPC server with the newer asyncio API (grpc.aio), and this is the service class:

```python
class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
    def __init__(self, model: Model):
        self.model = model
        # Force inference mode for the lifetime of EmbeddingService
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2), device="cuda")
        return embed_pb2.HealthResponse()

    async def Embed(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        # Synchronous call: nothing is awaited while the model runs
        embeddings = self.model.embed(batch)

        return embed_pb2.EmbedResponse(embeddings=embeddings)

    async def Predict(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        scores = self.model.predict(batch)

        return embed_pb2.PredictResponse(scores=scores)
```

The methods are declared async, but then I looked at the embed() method in models/default_model.py:

```python
class DefaultModel(Model):
    def __init__(
        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
    ):
        model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
        self.hidden_size = model.config.hidden_size
        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)

        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )
        self.has_token_type_ids = (
            inspect.signature(model.forward).parameters.get("token_type_ids", None)
            is not None
        )

        super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        if self.has_token_type_ids:
            kwargs["token_type_ids"] = batch.token_type_ids
        if self.has_position_ids:
            kwargs["position_ids"] = batch.position_ids
        output = self.model(**kwargs)

        pooling_features = {
            "token_embeddings": output[0],
            "attention_mask": batch.attention_mask,
        }
        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]

        cpu_results = embedding.view(-1).tolist()

        return [
            Embedding(
                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
            )
            for i in range(len(batch))
        ]

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Score]:
        pass
```

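It's not clear to me whether this achieves any concurrency, since the Embed() coroutine calls the synchronous self.model.embed() directly and awaits nothing while the model runs.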

The gRPC asyncio docs clearly state:

> Making blocking function calls in coroutines or in the thread running event loop will block the event loop, potentially starving all RPCs in the process.
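The effect described in that quote can be reproduced without gRPC. In this sketch, `time.sleep` is a hypothetical stand-in for a blocking model call such as `self.model.embed(batch)`, and the handler names are illustrative. The handler that calls it directly processes four concurrent requests one after another, while the one that offloads it with `asyncio.to_thread` (Python 3.9+) lets them overlap.

```python
import asyncio
import time


def blocking_embed(text: str) -> str:
    # Stand-in for a blocking model call such as self.model.embed(batch).
    time.sleep(0.2)
    return text.upper()


async def handler_blocking(text: str) -> str:
    # Calls the blocking function directly: the event loop is stuck for 0.2 s
    # and cannot start any other request in the meantime.
    return blocking_embed(text)


async def handler_offloaded(text: str) -> str:
    # Hands the call to the loop's default thread pool; the loop keeps running.
    return await asyncio.to_thread(blocking_embed, text)


async def timed(handler, n: int = 4) -> float:
    # Fires n requests concurrently and returns the wall-clock time taken.
    start = time.perf_counter()
    await asyncio.gather(*(handler(f"req{i}") for i in range(n)))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"blocking:  {asyncio.run(timed(handler_blocking)):.2f}s")   # roughly n * 0.2 s
    print(f"offloaded: {asyncio.run(timed(handler_offloaded)):.2f}s")  # roughly 0.2 s
```

Threads only help in this demo because `time.sleep` releases the GIL; pure-Python CPU-bound work would still run one thread at a time, which is the case where a process pool becomes relevant.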

So what is the best way to go about this? I know I could deploy TEI containers, but I'm looking for more control. Should I dig further into running embedding models in a process pool, or is there something else I'm missing?
