ka1kuk committed on
Commit
73c21f4
1 Parent(s): 03c75c9

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +25 -6
apis/chat_api.py CHANGED
@@ -97,19 +97,29 @@ class ChatAPIApp:
97
  description="(list) Messages",
98
  )
99
  temperature: Union[float, None] = Field(
100
- default=0,
101
  description="(float) Temperature",
102
  )
 
 
 
 
103
  max_tokens: Union[int, None] = Field(
104
  default=-1,
105
  description="(int) Max tokens",
106
  )
107
- stream: bool = Field(
108
  default=False,
 
 
 
 
109
  description="(bool) Stream",
110
  )
111
 
112
- def chat_completions(self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)):
 
 
113
  streamer = MessageStreamer(model=item.model)
114
  composer = MessageComposer(model=item.model)
115
  composer.merge(messages=item.messages)
@@ -118,8 +128,10 @@ class ChatAPIApp:
118
  stream_response = streamer.chat_response(
119
  prompt=composer.merged_str,
120
  temperature=item.temperature,
 
121
  max_new_tokens=item.max_tokens,
122
  api_key=api_key,
 
123
  )
124
  if item.stream:
125
  event_source_response = EventSourceResponse(
@@ -133,7 +145,16 @@ class ChatAPIApp:
133
  data_response = streamer.chat_return_dict(stream_response)
134
  return data_response
135
 
136
- def setup_routes(self):
 
 
 
 
 
 
 
 
 
137
  for prefix in ["", "/v1", "/api", "/api/v1"]:
138
  if prefix in ["/api/v1"]:
139
  include_in_schema = True
@@ -153,8 +174,6 @@ class ChatAPIApp:
153
  )(self.chat_completions)
154
 
155
 
156
-
157
-
158
  class ArgParser(argparse.ArgumentParser):
159
  def __init__(self, *args, **kwargs):
160
  super(ArgParser, self).__init__(*args, **kwargs)
 
97
  description="(list) Messages",
98
  )
99
  temperature: Union[float, None] = Field(
100
+ default=0.5,
101
  description="(float) Temperature",
102
  )
103
+ top_p: Union[float, None] = Field(
104
+ default=0.95,
105
+ description="(float) top p",
106
+ )
107
  max_tokens: Union[int, None] = Field(
108
  default=-1,
109
  description="(int) Max tokens",
110
  )
111
+ use_cache: bool = Field(
112
  default=False,
113
+ description="(bool) Use cache",
114
+ )
115
+ stream: bool = Field(
116
+ default=True,
117
  description="(bool) Stream",
118
  )
119
 
120
+ def chat_completions(
121
+ self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
122
+ ):
123
  streamer = MessageStreamer(model=item.model)
124
  composer = MessageComposer(model=item.model)
125
  composer.merge(messages=item.messages)
 
128
  stream_response = streamer.chat_response(
129
  prompt=composer.merged_str,
130
  temperature=item.temperature,
131
+ top_p=item.top_p,
132
  max_new_tokens=item.max_tokens,
133
  api_key=api_key,
134
+ use_cache=item.use_cache,
135
  )
136
  if item.stream:
137
  event_source_response = EventSourceResponse(
 
145
  data_response = streamer.chat_return_dict(stream_response)
146
  return data_response
147
 
148
+ def get_readme(self):
149
+ readme_path = Path(__file__).parents[1] / "README.md"
150
+ with open(readme_path, "r", encoding="utf-8") as rf:
151
+ readme_str = rf.read()
152
+ readme_html = markdown2.markdown(
153
+ readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
154
+ )
155
+ return readme_html
156
+
157
+ def setup_routes(self):
158
  for prefix in ["", "/v1", "/api", "/api/v1"]:
159
  if prefix in ["/api/v1"]:
160
  include_in_schema = True
 
174
  )(self.chat_completions)
175
 
176
 
 
 
177
  class ArgParser(argparse.ArgumentParser):
178
  def __init__(self, *args, **kwargs):
179
  super(ArgParser, self).__init__(*args, **kwargs)