|
import os |
|
import requests |
|
import json |
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Markdown rendered at the top of the demo page (passed to gr.Markdown).
DESCRIPTION = """
# Demo: Breeze-7B-Instruct-v0.1

Breeze-7B is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.

[Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v0.1) is the base model for the Breeze-7B series.
It is suitable for use if you have substantial fine-tuning data to tune it for your specific use case.

[Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v0.1) derives from the base model Breeze-7B-Base, making the resulting model amenable to be used as-is for commonly seen tasks.

[Breeze-7B-Instruct-64k](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-64k-v0.1) is a slightly modified version of
Breeze-7B-Instruct to enable a 64k-token context length. Roughly speaking, that is equivalent to 88k Traditional Chinese characters.

The current release version of Breeze-7B is v0.1.

*A project by the members (in alphabetical order): Chan-Jan Hsu 許湛然, Chang-Le Liu 劉昶樂, Feng-Ting Liao 廖峰挺, Po-Chun Hsu 許博竣, Yi-Chang Chen 陳宜昌, and the supervisor Da-Shan Shiu 許大山.*

**免責聲明: Breeze-7B-Instruct 和 Breeze-7B-Instruct-64k 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。**
"""

# Markdown rendered at the bottom of the page; currently whitespace only.
LICENSE = """

"""
|
|
|
# Initial value of the "System prompt" textbox in the advanced options.
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."

# Inference endpoint and credential, both read from the environment.
# NOTE(review): neither is validated here; an unset TOKEN produces the header
# value "Bearer None" and an unset API_URL is passed to requests as None.
API_URL = os.getenv("API_URL")
TOKEN = os.getenv("TOKEN")

# Headers sent with every completion request.
HEADERS = {
    "accept": "application/json",
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/json",
}

# Fixed fields of the completion request payload.
MODEL_NAME = "breeze-7b-instruct-v01"
PRESENCE_PENALTY = 0
FREQUENCY_PENALTY = 0

# The tokenizer is only used for its chat template (prompt formatting);
# generation itself happens behind API_URL.
model_name = "MediaTek-Research/Breeze-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # ------------------------------------------------------------------
    # Layout
    # ------------------------------------------------------------------
    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
        )
        submit_button = gr.Button(
            'Submit',
            variant='primary',
            scale=1,
            min_width=0,
        )

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the user message removed by Retry/Undo between chained steps.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(
            label='System prompt',
            value=DEFAULT_SYSTEM_PROMPT,
            lines=6,
        )
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=32,
            maximum=1024,
            step=1,
            value=512,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.01,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.01,
            maximum=1.0,
            step=0.01,
            value=0.01,
        )

    # ------------------------------------------------------------------
    # Handlers
    # ------------------------------------------------------------------

    # (connect, read) timeouts in seconds for the completion request. Without
    # a timeout a stalled backend would block the request indefinitely.
    REQUEST_TIMEOUT = (10, 120)

    def user(user_message, history):
        """Append the submitted message to the chat as a new, unanswered turn.

        Returns:
            A pair ``("", new_history)``: the empty string clears the input
            textbox, and ``new_history`` is a new list ending with
            ``[user_message, None]``.
        """
        return "", history + [[user_message, None]]

    def _build_prompt(history, system_prompt):
        """Render the conversation into a single prompt string.

        The last row of ``history`` is the turn awaiting a reply, so only its
        user message is used. Its reply slot is a placeholder (``None`` after
        a submit, ``''`` after a retry) and must not be sent to the model as
        an assistant turn.
        """
        chat_data = []
        system_prompt = system_prompt.strip()
        if system_prompt:
            chat_data.append({"role": "system", "content": system_prompt})

        *earlier_turns, pending_turn = history
        for user_msg, assistant_msg in earlier_turns:
            if user_msg is not None:
                chat_data.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                chat_data.append({"role": "assistant", "content": assistant_msg})
        if pending_turn[0] is not None:
            chat_data.append({"role": "user", "content": pending_turn[0]})

        prompt = tokenizer.apply_chat_template(chat_data, tokenize=False)
        # Drop the leading "<s>" emitted by the chat template, but only when it
        # is actually there (the original sliced off 3 characters blindly).
        # NOTE(review): presumably the server adds its own BOS token; confirm.
        return prompt.removeprefix("<s>")

    def bot(history, max_new_tokens, temperature, top_p, system_prompt):
        """Stream the model's reply into the last row of ``history``.

        Args:
            history: Chat rows of ``[user_message, assistant_message]``; the
                last row is the turn to answer and is updated in place.
            max_new_tokens: Sent as ``max_tokens`` (converted to int).
            temperature: Sampling temperature; 0.01 is added before sending.
            top_p: Sent as ``top_p`` (converted to float).
            system_prompt: Optional system message; ignored if blank.

        Yields:
            ``history`` after every received text chunk, and once more if a
            trailing ``</s>`` had to be stripped from the final reply.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            requests.Timeout: If the endpoint exceeds ``REQUEST_TIMEOUT``.
        """
        if not history:
            # Nothing to answer; hand the (empty) chat back unchanged.
            yield history
            return

        message = _build_prompt(history, system_prompt)

        data = {
            "model": MODEL_NAME,
            "prompt": str(message),
            # NOTE(review): the value sent is the slider value plus 0.01;
            # kept as in the original, confirm this offset is intended.
            "temperature": float(temperature) + 0.01,
            "n": 1,
            "max_tokens": int(max_new_tokens),
            "stop": "",
            "top_p": float(top_p),
            "logprobs": 0,
            "echo": False,
            "presence_penalty": PRESENCE_PENALTY,
            "frequency_penalty": FREQUENCY_PENALTY,
            "stream": True,
        }

        reply = ""
        with requests.post(
            API_URL,
            headers=HEADERS,
            data=json.dumps(data),
            stream=True,
            timeout=REQUEST_TIMEOUT,
        ) as r:
            # Fail with a clear HTTP error instead of a KeyError while parsing
            # an error body that has no "choices" field.
            r.raise_for_status()
            for raw_line in r.iter_lines():
                if not raw_line:
                    continue
                # Lines arrive as "data: <json>"; the stream ends with
                # "data: [DONE]".
                text = raw_line.decode().removeprefix("data:").strip()
                if not text:
                    continue
                if text == "[DONE]":
                    break
                reply += json.loads(text)["choices"][0]["text"]
                history[-1][1] = reply
                yield history

        # Hide the end-of-sequence marker if the model emitted it as text.
        if reply.endswith('</s>'):
            history[-1][1] = reply.removesuffix('</s>')
            yield history

        print(f"== Record == Query: {message!r}\n{history[-1][1]!r}")

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        """Remove the last chat turn and return it together with its user text.

        Returns:
            ``(history, message)`` where ``history`` is the same list with its
            last row popped and ``message`` is that row's user message, or
            ``''`` when the chat was empty or the message was falsy.
        """
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        """Re-add ``message`` as a new turn with an empty reply placeholder.

        The list is modified in place and returned.
        """
        history.append((message, ''))
        return history

    # ------------------------------------------------------------------
    # Event wiring
    # ------------------------------------------------------------------

    # Components fed to ``bot``; the order matches its parameter order.
    bot_inputs = [
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        system_prompt,
    ]

    # Enter in the textbox and the Submit button behave identically: record
    # the user turn, then stream the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )
    submit_button.click(
        user, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )

    # Retry: drop the last turn, re-add its user message, generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )

    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=msg,
        api_name=False,
        queue=False,
    )

    # Clear: reset the chat display.
    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown(LICENSE)
|
|
|
# Enable request queuing with the original settings, then start the server.
demo.queue(max_size=16, concurrency_count=1)
demo.launch()
|
|