import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator

# Read the Hugging Face access token from the environment (None for public models)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained("Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0", device_map="auto", token=HF_TOKEN)
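# Note: device_map="auto" relies on the `accelerate` package to place the model
# weights across the available GPU(s), so accelerate must be installed.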

@spaces.GPU(duration=120)
def gemma(message: str,
          history: list,
          temperature: float,
          max_new_tokens: int) -> Iterator[str]:
    """
    Generate a streaming response using the Gemma model.

    Args:
        message (str): The input message.
        history (list): The conversation history supplied by ChatInterface
            (not included in the single-turn prompt below).
        temperature (float): The sampling temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The response generated so far, growing as new tokens stream in.
    """
    # ChatInterface passes the running conversation history, but the prompt
    # template below is single-turn, so only the current message is used.
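    # The template below follows the Alpaca-style format the Navarasa models
    # were fine-tuned on (per the model card), so tokenizer.apply_chat_template
    # is not used here.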
    input_prompt = """
### Instruction:
You are an AI assistant. Engage in a conversation with the user and provide helpful responses.
### Input:
{}
### Response:
"""
    input_text = input_prompt.format(message)
    inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
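    # TextIteratorStreamer lets model.generate() run in a background thread
    # while this generator yields the text produced so far; skip_prompt drops
    # the echoed input, and the timeout guards against a stalled generation.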
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    # Fall back to greedy decoding (do_sample=False) when the temperature is 0,
    # since sampling with temperature=0 would crash.
    if temperature == 0:
        generate_kwargs['do_sample'] = False
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
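    # The streamer is exhausted only once generate() has finished, so this
    # join returns immediately; it just cleans up the worker thread.
    t.join()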

# Gradio block
chatbot = gr.Chatbot(placeholder="Prompt away in your local language", height=500)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=gemma,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.95,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=512,
                      label="Max new tokens",
                      render=False),
        ],
        examples=[
            ["Tell me a story of a crow in Malayalam"],
            ["Explain to me what AI is in Hindi"],
            ["मुझे एक कौवे की कहानी बताओ"],
            ["ఒక కాకి కథ చెప్పండి"],
            ["मला कावळ्याची गोष्ट सांगा"],
            ["مجھے کوے کی کہانی سناؤ"],
            ["কাউৰীৰ কাহিনী এটা কওকচোন"],
            ["मलाई कागको कथा सुनाउनुहोस्"],
            ["مون کي ڪانءَ جي ڪهاڻي ٻڌاءِ"],
            ["ஒரு காகத்தின் கதையைச் சொல்லுங்கள்"],
            ["ಒಂದು ಕಾಗೆಯ ಕಥೆ ಹೇಳು"],
            ["ഒരു കാക്കയുടെ കഥ പറയൂ"],
            ["મને કાગડાની વાર્તા કહો"],
            ["ਮੈਨੂੰ ਇੱਕ ਕਾਂ ਦੀ ਕਹਾਣੀ ਸੁਣਾਓ"],
            ["একটা কাকের গল্প বল"],
            ["ମୋତେ କାଉର କାହାଣୀ କୁହ।"]
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()