Phys
Phys

Reputation: 518

Script for streaming Mistral-7B LLM output only streams on server side. Client gets full output

I designed a remote server - client pipeline, which is supposed to load the model on the server and stream the output of the model. At the moment, the output is correctly streamed, but only inside the server, meaning that the output displayed on the client side is the full one. What is the issue?

Below the code to start the model on the server

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
from pydantic import BaseModel
import sys
import os
from queue import Queue
from threading import Thread

from inference import set_model, process_audio_streaming

app = FastAPI()

# Load the tokenizer and model
model, tokenizer, args = set_model() # replace with your script

class PredictionRequest(BaseModel):
    raw_feedback: str

class PredictionResponse(BaseModel):
    prediction: str

class CustomTextStreamer(TextStreamer):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.queue = Queue()

    def on_text(self, text: str, **kwargs):
        self.queue.put(text)

    def get_generated_text(self):
        while True:
            text = self.queue.get()
            if text is None:
                break
            yield text

def generate_text(prompt: str, max_new_tokens: int = 256):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = CustomTextStreamer(tokenizer)
    
    def generate():
        model.generate(inputs['input_ids'], streamer=streamer, max_new_tokens=max_new_tokens)
        streamer.queue.put(None)  # Signal the end of generation

    generation_thread = Thread(target=generate)
    generation_thread.start()

    return streamer.get_generated_text()


@app.post("/generate-text")
async def generate_text_endpoint(request: Request):
    body = await request.json()
    raw_feedback = body.get("raw_feedback")
    
    if raw_feedback is None is None:
        raise HTTPException(status_code=400, detail="raw_feedback is required")

    # Process the raw feedback and accuracy
    # raw_feedback = process_audio_streaming(raw_feedback, args)

    # Define the prompt using the processed feedback and accuracy
    prompt = (
        "tet prompt"
    )

    # Stream the text as it's generated
    return StreamingResponse(generate_text(prompt), media_type="text/event-stream") #text/plain


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

And the one to call the inference from the client

import requests

# URL of the FastAPI endpoint
url = 'http://localhost:8000/generate-text'

# Data to be sent to the endpoint
data = {}

try:
    with requests.post(url, json=data, stream=True) as r:
        r.raise_for_status()  # Ensure we catch HTTP errors

        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                print(chunk.decode('utf-8'), end='', flush=True)
except requests.RequestException as e:
    print(f"An error occurred: {e}")

Upvotes: 0

Views: 137

Answers (0)

Related Questions