LLM-api / apis /chat_api.py
ka1kuk's picture
Update apis/chat_api.py
b10be80 verified
raw
history blame contribute delete
No virus
10.2 kB
import argparse
import os
import sys
import time
import uvicorn
import requests
import asyncio
import logging
from pathlib import Path
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from typing import Union, List, Dict, Any
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock
from fastapi.middleware.cors import CORSMiddleware
class EmbeddingResponseItem(BaseModel):
    """One entry of the ``data`` list returned by the embeddings route."""

    object: str = "embedding"
    # Position of this vector within the response's ``data`` list.
    index: int
    # The nested form is tried first (the original declared type), so every
    # previously-valid payload validates exactly as before; a flat vector is
    # also accepted because the embeddings route may emit one element per
    # item after flattening the upstream result.
    # NOTE(review): confirm which shape the upstream API actually returns.
    embedding: Union[List[List[float]], List[float]]
class EmbeddingResponse(BaseModel):
    """Top-level body of the embeddings route, shaped like an OpenAI list object."""

    object: str = Field(default="list")
    # One item per returned vector.
    data: List[EmbeddingResponseItem]
    # Model id echoed back from the request.
    model: str
    # Free-form usage counters (e.g. prompt/total counts).
    usage: Dict[str, Any]
class ChatAPIApp:
    """FastAPI application serving model listing, chat completions and embeddings.

    Every route is registered under the prefixes "", "/v1", "/api" and
    "/api/v1"; only the "/api/v1" copies are included in the OpenAPI schema.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        # NOTE(review): wildcard origins combined with allow_credentials=True is
        # very permissive; narrow these lists if the API is publicly exposed.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # You can specify specific origins here
            allow_credentials=True,
            allow_methods=["*"],  # Or specify just the methods you need: ["GET", "POST"]
            allow_headers=["*"],  # Or specify headers you need
        )

    def get_available_models(self):
        """Return the static model catalogue in OpenAI ``/models`` list format.

        The dict is also stored on ``self.available_models``. Each entry's
        ``created`` field is the Unix time at which this method was called.
        """
        # https://platform.openai.com/docs/api-reference/models/list
        # ANCHOR[id=available-models]: Available models
        current_time = int(time.time())
        self.available_models = {
            "object": "list",
            "data": [
                {
                    "id": "mixtral-8x7b",
                    "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "mistral-7b",
                    "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "nous-mixtral-8x7b",
                    "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "NousResearch",
                },
                {
                    "id": "gemma-7b",
                    "description": "[google/gemma-7b-it]: https://huggingface.co/google/gemma-7b-it",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "Google",
                },
                {
                    "id": "codellama-7b",
                    "description": "[codellama/CodeLlama-7b-hf]: https://huggingface.co/codellama/CodeLlama-7b-hf",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "codellama",
                },
                {
                    "id": "bert-base-uncased",
                    "description": "[google-bert/bert-base-uncased]: https://huggingface.co/google-bert/bert-base-uncased",
                    "object": "embedding",
                    "created": current_time,
                    "owned_by": "google",
                },
            ],
        }
        return self.available_models

    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """FastAPI dependency that resolves the Hugging Face token for a request.

        Declared without ``self`` on purpose: it is referenced as
        ``Depends(extract_api_key)`` in method defaults while the class body is
        still executing, so FastAPI calls it as a plain function.

        The bearer token from the Authorization header wins; otherwise the
        ``HF_TOKEN`` environment variable is used.

        Returns:
            The token if it starts with ``"hf_"``; otherwise ``None`` after
            logging a warning.
        """
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if not api_key:
            logger.warn("Not provide HF Token!")
            return None
        if not api_key.startswith("hf_"):
            logger.warn("Invalid HF Token!")
            return None
        return api_key

    class QueryRequest(BaseModel):
        """Request body for the embeddings route."""

        input: str
        model: str = Field(default="bert-base-uncased")
        # Not read by ``embedding``; optional (defaulting to "float") so
        # clients that omit it are not rejected with a validation error.
        encoding_format: str = Field(default="float")

    class ChatCompletionsPostItem(BaseModel):
        """Request body for the chat-completions route."""

        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        """Run a chat completion for ``item`` against the selected model.

        The messages are merged into one prompt by ``MessageComposer`` and sent
        through ``MessageStreamer.chat_response``.

        Returns:
            An ``EventSourceResponse`` streaming the generator from
            ``streamer.chat_return_generator`` when ``item.stream`` is true;
            otherwise whatever ``streamer.chat_return_dict`` returns.
        """
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        # For offline testing: streamer.chat = stream_chat_mock
        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
        if item.stream:
            # NOTE(review): sse_starlette documents ``ping`` as an interval in
            # seconds; confirm whether 2000 (not 2) is intended.
            return EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )
        return streamer.chat_return_dict(stream_response)

    async def embedding(
        self, request: QueryRequest, api_key: str = Depends(extract_api_key)
    ):
        """Compute embeddings through the Hugging Face feature-extraction pipeline.

        Raises:
            HTTPException(503): the upstream JSON object contains an ``error`` key.
            HTTPException(502): the upstream body is not valid JSON.
            HTTPException(500): the upstream JSON is not a non-empty list of lists.
        """
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"
        # Only attach the header when a token was resolved, instead of sending
        # the literal string "Bearer None".
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

        # requests is blocking and its Response is not awaitable; run the call
        # in the loop's default executor so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                api_url, headers=headers, json={"inputs": request.input}
            ),
        )

        try:
            result = response.json()
        except ValueError as err:
            logging.error(f"Non-JSON response from Hugging Face API: {response.text[:200]}")
            raise HTTPException(
                status_code=502, detail="Invalid response from Hugging Face API."
            ) from err

        # Only a JSON object can carry an "error" key; on a string result the
        # membership test would be a substring match.
        if isinstance(result, dict) and "error" in result:
            error_detail = result["error"]
            logging.error(f"Error from Hugging Face API: {error_detail}")
            raise HTTPException(status_code=503, detail=f"The model is currently loading, please re-run the query. Detail: {error_detail}")

        if not (isinstance(result, list) and len(result) > 0 and isinstance(result[0], list)):
            logging.error(f"Unexpected response format: {result}")
            raise HTTPException(status_code=500, detail="Unexpected response format.")

        # Drop the outermost nesting level; each remaining element becomes one item.
        # NOTE(review): the element type must match EmbeddingResponseItem.embedding;
        # confirm against the real response shape for each model.
        flattened_embeddings = [vector for sublist in result for vector in sublist]
        data = [
            {"object": "embedding", "index": i, "embedding": vector}
            for i, vector in enumerate(flattened_embeddings)
        ]
        return EmbeddingResponse(
            object="list",
            data=data,
            model=request.model,
            # NOTE(review): these are character counts of the input, not tokenizer tokens.
            usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)},
        )

    def setup_routes(self):
        """Register the models, chat-completions and embeddings routes under every prefix."""
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            # Only the /api/v1 copies are listed in the OpenAPI schema.
            include_in_schema = prefix == "/api/v1"

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            self.app.post(
                prefix + "/embeddings",
                summary="Generate embeddings for the given texts",
                include_in_schema=include_in_schema,
                response_model=EmbeddingResponse,
            )(self.embedding)
class ArgParser(argparse.ArgumentParser):
    """Command-line options for launching the API server.

    Arguments are parsed from ``sys.argv[1:]`` during construction and the
    resulting namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (flags, add_argument keyword options), registered in this order.
        option_specs = (
            (
                ("-s", "--server"),
                dict(type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API"),
            ),
            (
                ("-p", "--port"),
                dict(type=int, default=23333, help="Server Port for HF LLM Chat API"),
            ),
            (
                ("-d", "--dev"),
                dict(default=False, action="store_true", help="Run in dev mode"),
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)
        self.args = self.parse_args(sys.argv[1:])
# Module-level ASGI application; uvicorn loads it through the "__main__:app" import string.
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # --dev enables uvicorn's auto-reload; without it reload stays off.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

# python -m apis.chat_api     # [Docker] production mode
# python -m apis.chat_api -d  # [Dev] development mode