Reputation: 518
I designed a remote server-client pipeline that is supposed to load the model on the server and stream the model's output to the client. At the moment the output is streamed correctly, but only inside the server: the client only ever displays the full output once generation has finished. What is the issue?
Below is the code that starts the model on the server:
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
from pydantic import BaseModel
import sys
import os
from queue import Queue
from threading import Thread

from inference import set_model, process_audio_streaming

app = FastAPI()

# Load the tokenizer and model
model, tokenizer, args = set_model()  # replace with your script


class PredictionRequest(BaseModel):
    raw_feedback: str


class PredictionResponse(BaseModel):
    prediction: str


class CustomTextStreamer(TextStreamer):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.queue = Queue()

    def on_text(self, text: str, **kwargs):
        self.queue.put(text)

    def get_generated_text(self):
        while True:
            text = self.queue.get()
            if text is None:
                break
            yield text


def generate_text(prompt: str, max_new_tokens: int = 256):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = CustomTextStreamer(tokenizer)

    def generate():
        model.generate(inputs['input_ids'], streamer=streamer, max_new_tokens=max_new_tokens)
        streamer.queue.put(None)  # Signal the end of generation

    generation_thread = Thread(target=generate)
    generation_thread.start()
    return streamer.get_generated_text()


@app.post("/generate-text")
async def generate_text_endpoint(request: Request):
    body = await request.json()
    raw_feedback = body.get("raw_feedback")
    if raw_feedback is None:
        raise HTTPException(status_code=400, detail="raw_feedback is required")

    # Process the raw feedback and accuracy
    # raw_feedback = process_audio_streaming(raw_feedback, args)

    # Define the prompt using the processed feedback and accuracy
    prompt = (
        "test prompt"
    )

    # Stream the text as it's generated
    return StreamingResponse(generate_text(prompt), media_type="text/event-stream")  # text/plain


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
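For reference, transformers also ships a TextIteratorStreamer that already exposes the generated text as an iterator, so the queue plumbing above is not strictly required. Here is a minimal sketch of how the generation helper could look with it; generate_text_iter is a hypothetical name, the model and tokenizer objects are assumed to be the same ones loaded above, and skip_prompt / skip_special_tokens are optional decoding flags:

from threading import Thread
from transformers import TextIteratorStreamer

def generate_text_iter(prompt: str, max_new_tokens: int = 256):
    # TextIteratorStreamer is itself an iterator over generated text chunks
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Run generation in a background thread so the chunks can be consumed here
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
    ).start()

    for text in streamer:
        yield text

The endpoint can pass this generator to StreamingResponse exactly as in the code above.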
And here is the client code that calls the inference endpoint:
import requests

# URL of the FastAPI endpoint
url = 'http://localhost:8000/generate-text'

# Data to be sent to the endpoint
data = {}

try:
    with requests.post(url, json=data, stream=True) as r:
        r.raise_for_status()  # Ensure we catch HTTP errors
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                print(chunk.decode('utf-8'), end='', flush=True)
except requests.RequestException as e:
    print(f"An error occurred: {e}")
Upvotes: 0
Views: 137