File size: 4,115 Bytes
892f733
 
970e9e7
892f733
 
970e9e7
892f733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970e9e7
 
 
892f733
733ec1d
970e9e7
 
3845117
892f733
 
 
 
970e9e7
892f733
970e9e7
892f733
 
 
3f1bfc5
892f733
 
 
 
 
970e9e7
892f733
 
 
 
970e9e7
892f733
 
970e9e7
892f733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970e9e7
 
 
733ec1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# HTML heading shown at the top of the page (rendered through gr.HTML).
TITLE = "<h1><center>Chat with Gemma-2-27B-Chinese-Chat</center></h1>"

# HTML sub-heading that links to the model page on the Hugging Face Hub.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/shenzhi-wang/Gemma-2-27B-Chinese-Chat' target='_blank'>our model page</a> for details.</center></h3>"

# System prompt used by stream_chat when the user leaves the System field empty.
DEFAULT_SYSTEM = "You are a helpful assistant."

# Sample system prompt for the tool-calling example: it describes one tool
# (generate_password) and the JSON "Action:" format the model should answer in.
TOOL_EXAMPLE = '''You have access to the following tools:
```python
def generate_password(length: int, include_symbols: Optional[bool]):
    """
    Generate a random password.
    Args:
        length (int): The length of the password
        include_symbols (Optional[bool]): Include symbols in the password
    """
    pass
```
Write "Action:" followed by a list of actions in JSON that you want to call, e.g.
Action:
```json
[
    {
        "name": "tool name (one of [generate_password])",
        "arguments": "the input to the tool"
    }
]
```
'''

# Custom stylesheet passed to gr.Blocks; styles the element carrying the
# "duplicate-button" class (the DuplicateButton in the layout).
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""


# Hub repository that provides both the tokenizer and the model weights.
MODEL_ID = "shenzhi-wang/Gemma-2-27B-Chinese-Chat"

# Loaded once at import time and shared by every chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)


@spaces.GPU(duration=360)
def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
    """Stream the model's reply to ``message``.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_prompt, assistant_answer)`` pairs.
        system: System prompt; ``DEFAULT_SYSTEM`` is used when it is empty.
        temperature: Sampling temperature. ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Yields:
        The reply text accumulated so far, growing with every streamed chunk.
    """
    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    for prompt, answer in history:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    if temperature == 0:
        # Greedy decoding: do not forward a temperature at all, since it is a
        # sampling-only parameter and 0 is outside its valid (positive) range.
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not outlive this call.
    worker.join()


# Chat history widget; created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, model link and the "duplicate this Space" button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Extra controls are created with render=False and passed to ChatInterface.
    parameters_panel = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    system_box = gr.Text(value="", label="System", render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False
    )
    token_budget_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False
    )

    # Each example is a user message followed by the value for the System box.
    example_prompts = [
        ["我的蓝牙耳机坏了,我该去看牙科还是耳鼻喉科?", ""],
        ["7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?", ""],
        ["我的笔记本找不到了。", "扮演诸葛亮和我对话。"],
        ["我想要一个新的密码,长度为8位,包含特殊符号。", TOOL_EXAMPLE],
        ["How are you today?", "You are Taylor Swift, use beautiful lyrics to answer questions."],
        ["用C++实现KMP算法,并加上中文注释", ""],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_panel,
        additional_inputs=[system_box, temperature_slider, token_budget_slider],
        examples=example_prompts,
        cache_examples=False,
    )


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()