LLM-api / apis /chat_api.py
ka1kuk's picture
Update apis/chat_api.py
125cf33 verified
raw
history blame
No virus
10.5 kB
import argparse
import os
import sys
import time
import uvicorn
import requests
import asyncio
from pathlib import Path
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from typing import Union, List
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock
from fastapi.middleware.cors import CORSMiddleware
class ChatAPIApp:
    """FastAPI application exposing OpenAI-style chat and embedding endpoints.

    Chat completions are produced via ``MessageComposer`` (prompt merging) and
    ``MessageStreamer`` (model call); embeddings are fetched from the
    HuggingFace Inference API ``feature-extraction`` pipeline.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        # NOTE(review): wildcard origins combined with allow_credentials=True
        # is very permissive; restrict origins if the API is exposed publicly.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # You can specify specific origins here
            allow_credentials=True,
            allow_methods=["*"],  # Or specify just the methods you need: ["GET", "POST"]
            allow_headers=["*"],  # Or specify headers you need
        )

    def get_available_models(self):
        """Return the model list in the OpenAI ``/models`` response shape.

        The result is also cached on ``self.available_models``.
        """
        # https://platform.openai.com/docs/api-reference/models/list
        # ANCHOR[id=available-models]: Available models
        current_time = int(time.time())
        self.available_models = {
            "object": "list",
            "data": [
                {
                    "id": "mixtral-8x7b",
                    "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "mistral-7b",
                    "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "nous-mixtral-8x7b",
                    "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "NousResearch",
                },
                {
                    "id": "gemma-7b",
                    "description": "[google/gemma-7b-it]: https://huggingface.co/google/gemma-7b-it",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "Google",
                },
                {
                    "id": "codellama-7b",
                    "description": "[codellama/CodeLlama-7b-hf]: https://huggingface.co/codellama/CodeLlama-7b-hf",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "codellama",
                },
                {
                    "id": "bert-base-uncased",
                    "description": "[google-bert/bert-base-uncased]: https://huggingface.co/google-bert/bert-base-uncased",
                    "object": "embedding",
                    "created": current_time,
                    "owned_by": "google",
                },
            ],
        }
        return self.available_models

    # Plain function (no ``self``) on purpose: it is referenced directly as a
    # FastAPI dependency via ``Depends(extract_api_key)`` inside this class body.
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """Resolve the HuggingFace token for a request.

        Uses the ``Authorization: Bearer`` header when present, otherwise the
        ``HF_TOKEN`` environment variable. Returns the token only if it starts
        with ``hf_``; otherwise logs a warning and returns ``None``.
        """
        api_key = None
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if api_key:
            if api_key.startswith("hf_"):
                return api_key
            else:
                logger.warn("Invalid HF Token!")
        else:
            logger.warn("Not provide HF Token!")
        return None

    class QueryRequest(BaseModel):
        """Request body for the embedding endpoint."""

        texts: List[str]
        model_name: str = Field(..., example="bert-base-uncased")
        api_key: str = Field(..., example="your_hf_api_key_here")

    class ChatCompletionsPostItem(BaseModel):
        """Request body for the chat-completions endpoint (OpenAI-like)."""

        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        """Run a chat completion for ``item``.

        Returns an SSE ``EventSourceResponse`` when ``item.stream`` is true,
        otherwise whatever ``MessageStreamer.chat_return_dict`` builds from the
        model response. ``api_key`` may be ``None`` if no valid token was found.
        """
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        # For offline testing, swap in the mock: streamer.chat = stream_chat_mock

        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                # Empty SSE comment lines act as keep-alive pings.
                ping_message_factory=lambda: ServerSentEvent(comment=""),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response

    async def chat_embedding(self, texts, model_name, api_key, timeout=120):
        """Fetch embeddings for ``texts`` from the HF feature-extraction pipeline.

        Args:
            texts: list of input strings.
            model_name: HuggingFace model id used in the pipeline URL.
            api_key: HuggingFace token sent as a Bearer token.
            timeout: seconds before the upstream HTTP request is abandoned.

        Returns:
            The decoded JSON payload when it is a non-empty list whose first
            element is a list (one entry per input text).

        Raises:
            RuntimeError: the API returned an ``{"error": ...}`` payload
                (assumed to mean the model is still loading) or an
                unexpected format.
        """
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
        headers = {"Authorization": f"Bearer {api_key}"}
        # ``requests`` is blocking; run it in the default executor so a slow
        # upstream call does not freeze the event loop for all other clients.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                api_url, headers=headers, json={"inputs": texts}, timeout=timeout
            ),
        )
        result = response.json()
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result
        if isinstance(result, dict) and "error" in result:
            raise RuntimeError("The model is currently loading, please re-run the query.")
        raise RuntimeError("Unexpected response format.")

    async def embedding(self, request: QueryRequest):
        """Return embeddings for ``request.texts`` in an OpenAI-like list shape.

        Retries up to 3 times (10 s apart) while ``chat_embedding`` raises
        ``RuntimeError``. Responds 503 if every attempt fails that way and
        500 for any other error.
        """
        max_attempts = 3
        retry_delay_seconds = 10
        for attempt in range(max_attempts):
            try:
                embeddings = await self.chat_embedding(
                    request.texts, request.model_name, request.api_key
                )
            except RuntimeError:
                if attempt < max_attempts - 1:  # Don't sleep on the last attempt
                    await asyncio.sleep(retry_delay_seconds)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e
            else:
                data = [
                    {"object": "embedding", "index": i, "embedding": embedding}
                    for i, embedding in enumerate(embeddings)
                ]
                return {
                    "object": "list",
                    "data": data,
                    "model": request.model_name,
                    "usage": {"prompt_tokens": len(request.texts), "total_tokens": len(request.texts)},
                }
        # Raised outside any try block so it is not re-wrapped as a 500.
        raise HTTPException(status_code=503, detail="The model is currently loading, please try again later.")

    def setup_routes(self):
        """Register the routes under each of the supported URL prefixes."""
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            # Only one prefix is shown in the OpenAPI schema to avoid duplicates.
            include_in_schema = prefix in ["/api/v1"]

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            include_in_schema = prefix in ["/v1"]

            # No ``response_model``: ``embedding`` returns a dict, and a
            # ``List`` response model would fail FastAPI's response validation.
            self.app.post(
                prefix + "/embedding",  # Use the specific prefix for this route
                summary="Generate embeddings for the given texts",
                include_in_schema=include_in_schema,
            )(self.embedding)
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for launching the HF LLM Chat API server.

    The command line (``sys.argv[1:]``) is parsed as soon as the parser is
    constructed, and the resulting namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (flags, keyword options) for each option, in ``--help`` order.
        option_specs = [
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 23333,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
            (
                ("-d", "--dev"),
                {
                    "default": False,
                    "action": "store_true",
                    "help": "Run in dev mode",
                },
            ),
        ]
        for flags, options in option_specs:
            self.add_argument(*flags, **options)
        self.args = self.parse_args(sys.argv[1:])
# Module-level ASGI app so uvicorn can import it as "<module>:app".
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # ``--dev`` turns on uvicorn's auto-reload; otherwise it runs without it.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

# python -m apis.chat_api  # [Docker] on production mode
# python -m apis.chat_api -d  # [Dev] on development mode