YC-Chen committed
Commit
b45f299
1 Parent(s): e6cbc32

Update app.py

Files changed (1)
  1. app.py +204 -1
app.py CHANGED
@@ -1,3 +1,206 @@
 
 
+ import os
+ import json
+
  import gradio as gr
- gr.load("models/MediaTek-Research/Breeze-7B-Instruct-v0.1").launch()
+ import requests
+ from transformers import AutoTokenizer
+
+
+ DESCRIPTION = """
+
+ """
+
+ LICENSE = """
+
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ""
+
+ # Remote inference endpoint and bearer token, read from the environment.
+ API_URL = os.environ.get("API_URL")
+ TOKEN = os.environ.get("TOKEN")
+
+ HEADER = {
+     "accept": "application/json",
+     "Authorization": f"Bearer {TOKEN}",
+     "Content-Type": "application/json",
+ }
+
+ MODEL_NAME = "breeze-7b-instruct-v01"
+ TEMPERATURE = 1
+ MAX_TOKENS = 16
+ TOP_P = 0
+ PRESENCE_PENALTY = 0
+ FREQUENCY_PENALTY = 0
+
+
+ eos_token = "</s>"
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 1536
+
+ max_prompt_length = 8192 - MAX_MAX_NEW_TOKENS - 10
+
+ model_name = "MediaTek-Research/Breeze-7B-Instruct-v0.1"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(DESCRIPTION)
+
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         msg = gr.Textbox(
+             container=False,
+             show_label=False,
+             placeholder='Type a message...',
+             scale=10,
+         )
+         submit_button = gr.Button('Submit',
+                                   variant='primary',
+                                   scale=1,
+                                   min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button('🔄 Retry', variant='secondary')
+         undo_button = gr.Button('↩️ Undo', variant='secondary')
+         clear = gr.Button('🗑️ Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label='Advanced options', open=False):
+         system_prompt = gr.Textbox(label='System prompt',
+                                    value=DEFAULT_SYSTEM_PROMPT,
+                                    lines=6)
+         max_new_tokens = gr.Slider(
+             label='Max new tokens',
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         )
+         temperature = gr.Slider(
+             label='Temperature',
+             minimum=0.1,
+             maximum=1.0,
+             step=0.1,
+             value=0.3,
+         )
+         top_p = gr.Slider(
+             label='Top-p (nucleus sampling)',
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.95,
+         )
+         top_k = gr.Slider(
+             label='Top-k',
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         )
+
+     def user(user_message, history):
+         # Append the user's message to the history and clear the textbox.
+         return "", history + [[user_message, None]]
+
+
+     def bot(history, max_new_tokens, temperature, top_p, top_k, system_prompt):
+         # Rebuild the conversation for the remote endpoint from the chat
+         # history plus the system prompt.
+         messages = [{"role": "system", "content": system_prompt}]
+         for user_msg, assistant_msg in history:
+             if user_msg:
+                 messages.append({"role": "user", "content": user_msg})
+             if assistant_msg:
+                 messages.append({"role": "assistant", "content": assistant_msg})
+
+         data = {
+             "model": MODEL_NAME,
+             "messages": messages,
+             "temperature": temperature,
+             "n": 1,
+             "max_tokens": max_new_tokens,
+             "stop": "",
+             "top_p": top_p,
+             "logprobs": 0,
+             "echo": False,
+             "presence_penalty": PRESENCE_PENALTY,
+             "frequency_penalty": FREQUENCY_PENALTY,
+         }
+         # Note: top_k is collected from the UI but this payload has no field for it.
+
+         outputs = requests.post(API_URL, headers=HEADER, data=json.dumps(data)).json()
+         # Assumes an OpenAI-style chat-completions response shape.
+         reply = outputs["choices"][0]["message"]["content"]
+         history[-1] = [history[-1][0], reply]
+         return history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot,
+     )
+     submit_button.click(
+         user, [msg, chatbot], [msg, chatbot], queue=False
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot,
+     )
+
+
+     def delete_prev_fn(
+             history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+         # Drop the most recent exchange and return its user message.
+         try:
+             message, _ = history.pop()
+         except IndexError:
+             message = ''
+         return history, message or ''
+
+
+     def display_input(message: str,
+                       history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+         # Re-insert the saved user message with an empty placeholder response.
+         history.append((message, ''))
+         return history
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=msg,
+         api_name=False,
+         queue=False,
+     )
+
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     gr.Markdown(LICENSE)
+
+ demo.queue(concurrency_count=4, max_size=128)
+ demo.launch()
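
For a quick smoke test of the endpoint outside the UI, a minimal sketch along these lines could be used, assuming the API_URL and TOKEN environment variables are set and the service accepts the same chat-style payload that bot() sends; the response shape is an assumption, not something this commit confirms:

import json
import os

import requests

API_URL = os.environ["API_URL"]  # assumed to point at the completion endpoint
HEADERS = {
    "accept": "application/json",
    "Authorization": f"Bearer {os.environ['TOKEN']}",
    "Content-Type": "application/json",
}

payload = {
    "model": "breeze-7b-instruct-v01",
    "messages": [
        {"role": "system", "content": ""},
        {"role": "user", "content": "Hello"},
    ],
    "temperature": 0.3,
    "max_tokens": 128,
    "top_p": 0.95,
}

resp = requests.post(API_URL, headers=HEADERS, data=json.dumps(payload)).json()
print(resp)  # inspect the raw JSON to confirm the actual response shape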