ka1kuk commited on
Commit
829d976
1 Parent(s): e994f7a

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +40 -24
apis/chat_api.py CHANGED
@@ -2,11 +2,11 @@ import argparse
2
  import os
3
  import sys
4
  import uvicorn
5
- import traceback
6
 
7
  from fastapi import FastAPI, Depends
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from pydantic import BaseModel, Field
 
10
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
11
  from utils.logger import logger
12
  from networks.message_streamer import MessageStreamer
@@ -25,21 +25,27 @@ class ChatAPIApp:
25
  self.setup_routes()
26
 
27
  def get_available_models(self):
 
28
  # ANCHOR[id=available-models]: Available models
29
- self.available_models = [
30
- {
31
- "id": "mixtral-8x7b",
32
- "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
33
- },
34
- {
35
- "id": "mistral-7b",
36
- "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2"
37
- },
38
- {
39
- "id": "openchat-3.5",
40
- "description": "[openchat/openchat-3.5-1210]: https://huggingface.co/openchat/openchat-3.5-1210",
41
- },
42
- ]
 
 
 
 
 
43
  return self.available_models
44
 
45
  def extract_api_key(
@@ -47,10 +53,20 @@ class ChatAPIApp:
47
  HTTPBearer(auto_error=False)
48
  ),
49
  ):
 
50
  if credentials:
51
- return credentials.credentials
52
  else:
53
- return os.getenv("HF_TOKEN") or None
 
 
 
 
 
 
 
 
 
54
 
55
  class ChatCompletionsPostItem(BaseModel):
56
  model: str = Field(
@@ -61,16 +77,16 @@ class ChatAPIApp:
61
  default=[{"role": "user", "content": "Hello, who are you?"}],
62
  description="(list) Messages",
63
  )
64
- temperature: float = Field(
65
- default=0.01,
66
  description="(float) Temperature",
67
  )
68
- max_tokens: int = Field(
69
- default=4096,
70
  description="(int) Max tokens",
71
  )
72
  stream: bool = Field(
73
- default=False,
74
  description="(bool) Stream",
75
  )
76
 
@@ -101,7 +117,7 @@ class ChatAPIApp:
101
  return data_response
102
 
103
  def setup_routes(self):
104
- for prefix in ["", "/v1"]:
105
  self.app.get(
106
  prefix + "/models",
107
  summary="Get available models",
@@ -153,4 +169,4 @@ if __name__ == "__main__":
153
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
154
 
155
  # python -m apis.chat_api # [Docker] on product mode
156
- # python -m apis.chat_api -d # [Dev] on develop mode
 
2
  import os
3
  import sys
4
  import uvicorn
 
5
 
6
  from fastapi import FastAPI, Depends
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
  from pydantic import BaseModel, Field
9
+ from typing import Union
10
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
11
  from utils.logger import logger
12
  from networks.message_streamer import MessageStreamer
 
25
  self.setup_routes()
26
 
27
  def get_available_models(self):
28
+ # https://platform.openai.com/docs/api-reference/models/list
29
  # ANCHOR[id=available-models]: Available models
30
+ self.available_models = {
31
+ "object": "list",
32
+ "data": [
33
+ {
34
+ "id": "mixtral-8x7b",
35
+ "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
36
+ "object": "model",
37
+ "created": 1700000000,
38
+ "owned_by": "mistralai",
39
+ },
40
+ {
41
+ "id": "mistral-7b",
42
+ "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
43
+ "object": "model",
44
+ "created": 1700000000,
45
+ "owned_by": "mistralai",
46
+ },
47
+ ],
48
+ }
49
  return self.available_models
50
 
51
  def extract_api_key(
 
53
  HTTPBearer(auto_error=False)
54
  ),
55
  ):
56
+ api_key = None
57
  if credentials:
58
+ api_key = credentials.credentials
59
  else:
60
+ api_key = os.getenv("HF_TOKEN")
61
+
62
+ if api_key:
63
+ if api_key.startswith("hf_"):
64
+ return api_key
65
+ else:
66
+ logger.warn(f"Invalid HF Token!")
67
+ else:
68
+ logger.warn("Not provide HF Token!")
69
+ return None
70
 
71
  class ChatCompletionsPostItem(BaseModel):
72
  model: str = Field(
 
77
  default=[{"role": "user", "content": "Hello, who are you?"}],
78
  description="(list) Messages",
79
  )
80
+ temperature: Union[float, None] = Field(
81
+ default=0,
82
  description="(float) Temperature",
83
  )
84
+ max_tokens: Union[int, None] = Field(
85
+ default=-1,
86
  description="(int) Max tokens",
87
  )
88
  stream: bool = Field(
89
+ default=True,
90
  description="(bool) Stream",
91
  )
92
 
 
117
  return data_response
118
 
119
  def setup_routes(self):
120
+ for prefix in ["", "/v1", "/api", "/api/v1"]:
121
  self.app.get(
122
  prefix + "/models",
123
  summary="Get available models",
 
169
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
170
 
171
  # python -m apis.chat_api # [Docker] on product mode
172
+ # python -m apis.chat_api -d # [Dev] on develop mode