ka1kuk committed on
Commit 0ec6f70 • 1 Parent(s): 829d976

Update networks/message_streamer.py

Files changed (1)
  1. networks/message_streamer.py +52 -8
networks/message_streamer.py CHANGED
@@ -1,6 +1,7 @@
 import json
 import re
 import requests
+from tiktoken import get_encoding as tiktoken_get_encoding
 from messagers.message_outputer import OpenaiStreamOutputer
 from utils.logger import logger
 from utils.enver import enver
@@ -10,9 +11,11 @@ class MessageStreamer:
     MODEL_MAP = {
         "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",  # 72.62, fast [Recommended]
         "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",  # 65.71, fast
-        "openchat-3.5": "openchat/openchat-3.5-1210",  # ??, fast
-        # "zephyr-7b-alpha": "HuggingFaceH4/zephyr-7b-alpha",  # 59.5, fast
-        # "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta",  # 61.95, slow
+        # "openchat-3.5": "openchat/openchat-3.5-1210",  # 68.89, fast
+        # "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta",  # ❌ Too Slow
+        # "llama-70b": "meta-llama/Llama-2-70b-chat-hf",  # ❌ Require Pro User
+        # "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf",  # ❌ Low Score
+        # "falcon-180b": "tiiuae/falcon-180B-chat",  # ❌ Require Pro User
         "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
     }
     STOP_SEQUENCES_MAP = {
@@ -20,6 +23,12 @@ class MessageStreamer:
         "mistral-7b": "</s>",
         "openchat-3.5": "<|end_of_turn|>",
     }
+    TOKEN_LIMIT_MAP = {
+        "mixtral-8x7b": 32768,
+        "mistral-7b": 32768,
+        "openchat-3.5": 8192,
+    }
+    TOKEN_RESERVED = 100

     def __init__(self, model: str):
         if model in self.MODEL_MAP.keys():
@@ -28,19 +37,29 @@
             self.model = "default"
         self.model_fullname = self.MODEL_MAP[self.model]
         self.message_outputer = OpenaiStreamOutputer()
+        self.tokenizer = tiktoken_get_encoding("cl100k_base")

     def parse_line(self, line):
         line = line.decode("utf-8")
         line = re.sub(r"data:\s*", "", line)
         data = json.loads(line)
-        content = data["token"]["text"]
+        try:
+            content = data["token"]["text"]
+        except:
+            logger.err(data)
         return content

+    def count_tokens(self, text):
+        tokens = self.tokenizer.encode(text)
+        token_count = len(tokens)
+        logger.note(f"Prompt Token Count: {token_count}")
+        return token_count
+
     def chat_response(
         self,
         prompt: str = None,
-        temperature: float = 0.01,
-        max_new_tokens: int = 8192,
+        temperature: float = 0,
+        max_new_tokens: int = None,
         api_key: str = None,
     ):
         # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
@@ -58,6 +77,25 @@ class MessageStreamer:
             )
             self.request_headers["Authorization"] = f"Bearer {api_key}"

+        if temperature is None or temperature < 0:
+            temperature = 0.0
+        # temperature must be positive and <= 1 for HF LLM models
+        temperature = max(temperature, 0.01)
+        temperature = min(temperature, 1)
+
+        token_limit = int(
+            self.TOKEN_LIMIT_MAP[self.model]
+            - self.TOKEN_RESERVED
+            - self.count_tokens(prompt) * 1.35
+        )
+        if token_limit <= 0:
+            raise ValueError("Prompt exceeded token limit!")
+
+        if max_new_tokens is None or max_new_tokens <= 0:
+            max_new_tokens = token_limit
+        else:
+            max_new_tokens = min(max_new_tokens, token_limit)
+
         # References:
         # huggingface_hub/inference/_client.py:
         # class InferenceClient > def text_generation()
@@ -67,7 +105,7 @@ class MessageStreamer:
         self.request_body = {
             "inputs": prompt,
             "parameters": {
-                "temperature": max(temperature, 0.01),  # must be positive
+                "temperature": temperature,
                 "max_new_tokens": max_new_tokens,
                 "return_full_text": False,
             },
@@ -128,13 +166,17 @@ class MessageStreamer:
         if self.model in self.STOP_SEQUENCES_MAP.keys():
             final_content = final_content.replace(self.stop_sequences, "")

+        final_content = final_content.strip()
         final_output["choices"][0]["message"]["content"] = final_content
         return final_output

     def chat_return_generator(self, stream_response):
         is_finished = False
+        line_count = 0
         for line in stream_response.iter_lines():
-            if not line:
+            if line:
+                line_count += 1
+            else:
                 continue

             content = self.parse_line(line)
@@ -145,6 +187,8 @@ class MessageStreamer:
                 is_finished = True
             else:
                 content_type = "Completions"
+                if line_count == 1:
+                    content = content.lstrip()
             logger.back(content, end="")

             output = self.message_outputer.output(
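
For reference, the main behavioral change in this commit is that max_new_tokens is no longer fixed at 8192 but derived from the prompt length: the prompt is counted with tiktoken's cl100k_base encoding, padded by a factor of 1.35, and subtracted (together with TOKEN_RESERVED) from the model's context limit. Since cl100k_base is not the native tokenizer of these Hugging Face models, the 1.35 padding presumably over-counts on purpose to leave headroom. The snippet below is a minimal standalone sketch of that budgeting logic; the budget_max_new_tokens helper and the example call are illustrative only (not part of the repo), and it assumes tiktoken is installed.

    from tiktoken import get_encoding

    # Values mirrored from TOKEN_LIMIT_MAP / TOKEN_RESERVED in the diff.
    TOKEN_LIMIT_MAP = {"mixtral-8x7b": 32768, "mistral-7b": 32768, "openchat-3.5": 8192}
    TOKEN_RESERVED = 100

    def budget_max_new_tokens(model: str, prompt: str, max_new_tokens: int = None) -> int:
        # Count prompt tokens with cl100k_base, padded by 1.35 as in the commit.
        prompt_tokens = len(get_encoding("cl100k_base").encode(prompt))
        token_limit = int(TOKEN_LIMIT_MAP[model] - TOKEN_RESERVED - prompt_tokens * 1.35)
        if token_limit <= 0:
            raise ValueError("Prompt exceeded token limit!")
        # No explicit request: hand the whole remaining budget to generation;
        # otherwise cap the requested value at the remaining budget.
        if max_new_tokens is None or max_new_tokens <= 0:
            return token_limit
        return min(max_new_tokens, token_limit)

    # For a short prompt this is close to 32768 - 100.
    print(budget_max_new_tokens("mistral-7b", "Hello, world!"))

With this change, callers that previously passed max_new_tokens=8192 get the same or a larger budget on the 32k-context models, while overly long prompts now fail fast with a ValueError instead of an API-side error.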