File size: 4,439 Bytes
348c912
b35e8ee
 
 
 
348c912
b35e8ee
 
 
 
 
 
e3db908
 
b35e8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a38ad9
 
 
 
 
 
 
 
 
 
 
 
 
b35e8ee
 
 
 
0a38ad9
b35e8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a38ad9
b35e8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b242327
 
b35e8ee
 
 
 
 
 
 
 
 
 
 
 
 
b242327
b35e8ee
 
 
 
 
0a38ad9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# Read the Hugging Face access token from the environment (e.g. a Space secret).
# It is forwarded to both `from_pretrained` calls below so downloads can be
# authenticated; when the variable is unset, None is passed and the library's
# default behaviour applies.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Hub repository of the fine-tuned Indic Gemma checkpoint served by this app.
# Kept in one place so the tokenizer and model always come from the same repo.
MODEL_ID = "Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0"

# Load the tokenizer and model once at import time; every request reuses them.
# device_map="auto" delegates weight placement across available devices to
# transformers/accelerate.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", token=HF_TOKEN)


@spaces.GPU(duration=120)
def gemma(message: str,
          history: list,
          temperature: float,
          max_new_tokens: int,
          ) -> Iterator[str]:
    """
    Stream a response to ``message`` from the Navarasa (Indic Gemma) model.

    The model is prompted with a single-turn Instruction / Input / Response
    template, so only the latest ``message`` is sent to the model. ``history``
    is accepted because ChatInterface passes it, but it is not part of the
    prompt.

    Args:
        message (str): The latest user message.
        history (list): The conversation history supplied by ChatInterface
            (not used in the prompt, see above).
        temperature (float): Sampling temperature. A value of 0 or below
            selects greedy decoding instead of sampling.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The full response text accumulated so far; each yielded value
        extends the previous one.
    """
    input_prompt = """
### Instruction:
You are an AI assistant. Engage in a conversation with the user and provide helpful responses.

### Input:
{}

### Response:
"""

    input_text = input_prompt.format(message)

    inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

    # skip_prompt=True streams only newly generated text, not the echoed prompt;
    # timeout bounds how long the consumer loop below waits for each chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # A non-positive temperature is not valid for sampling, so fall back to
        # greedy decoding and omit `temperature`, which only affects sampling.
        generate_kwargs["do_sample"] = False

    # model.generate is a blocking call: run it in a worker thread and consume
    # its output incrementally through the streamer on this thread.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation has finished; reap the worker.
    generation_thread.join()
        

# Starter prompts (one per supported language) shown beneath the chat box.
# Each prompt becomes a single-value example row for ChatInterface.
_EXAMPLE_PROMPTS = [
    "Tell me a story of a crow in Malayalam",
    "Explain to me what is AI in Hindi",
    "मुझे एक कौवे की कहानी बताओ",
    "ఒక కాకి కథ చెప్పండి",
    "मला कावळ्याची गोष्ट सांगा",
    "مجھے کوے کی کہانی سناؤ",
    "কাউৰীৰ কাহিনী এটা কওকচোন",
    "मलाई कागको कथा सुनाउनुहोस्",
    "مون کي ڪانءَ جي ڪهاڻي ٻڌاءِ",
    "ஒரு காகத்தின் கதையைச் சொல்லுங்கள்",
    "ಒಂದು ಕಾಗೆಯ ಕಥೆ ಹೇಳು",
    "ഒരു കാക്കയുടെ കഥ പറയൂ",
    "મને કાગડાની વાર્તા કહો",
    "ਮੈਨੂੰ ਇੱਕ ਕਾਂ ਦੀ ਕਹਾਣੀ ਸੁਣਾਓ",
    "একটা কাকের গল্প বল",
    "ମୋତେ କାଉର କାହାଣୀ କୁହ |",
]

# Chat display component, created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(placeholder="Prompt away in your local language", height=500)

with gr.Blocks(fill_height=True) as demo:
    # Generation controls. They are created with render=False so they are not
    # laid out here directly; they are passed to ChatInterface instead.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    # The slider order must match gemma's extra parameters:
    # (temperature, max_new_tokens).
    gr.ChatInterface(
        fn=gemma,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()