AgrI_Assistant / app.py
shah1zil's picture
Update app.py
c630b36 verified
raw
history blame contribute delete
No virus
4.86 kB
import os
import gradio as gr
import whisper
from gtts import gTTS
import io
from groq import Groq
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
# Set up the Groq client. The API key must come from the environment (for
# example a Hugging Face Spaces secret), never from source code: a key that
# has been committed to a public repository is compromised, so the key that
# used to be hard-coded here should be revoked in the Groq console.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Provide it as an environment variable "
        "(e.g. a Space secret) before starting the app."
    )
client = Groq(api_key=groq_api_key)

# Load the Whisper speech-to-text model. Other model names such as "small",
# "medium" or "large" can be used instead of "base".
whisper_model = whisper.load_model("base")
# Load the RAG tokenizer and model ("himmeow/vi-gemma-2b-RAG").
#
# device_map="auto" lets transformers/accelerate place the layers on the
# available GPU(s) and CPU, spilling to disk under `offload_folder` when
# memory runs out, and it loads the real weights in one step. The previous
# sequence could not work:
#   * init_empty_weights + load_checkpoint_and_dispatch were given a Hub repo
#     id, but load_checkpoint_and_dispatch expects a local checkpoint path;
#   * Module.to_empty() requires a keyword-only `device` argument, and it
#     replaces every tensor with uninitialized memory, discarding the
#     weights that were just loaded;
#   * moving an accelerate-dispatched/offloaded model with `.to("cuda")`
#     is unsupported.
offload_folder = "./offload"
os.makedirs(offload_folder, exist_ok=True)

rag_tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
rag_model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder=offload_folder,
)
def load_pdf(pdf_path):
    """Return the extracted text of every page of the PDF at *pdf_path*.

    Each page's text is followed by a newline. A page from which no text can
    be extracted contributes just the newline instead of raising.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        The concatenated page text, or "" for a PDF with no pages.
    """
    with open(pdf_path, "rb") as file:
        reader = PdfReader(file)
        # Depending on the PyPDF2 version, extract_text() may return None or
        # "" for pages without a text layer; `or ""` avoids `None + "\n"`.
        # str.join avoids the potentially quadratic cost of `+=` in a loop.
        return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)
# Prompt layout for the RAG model. It has three positional `{}` slots, filled
# with str.format() in order: the context document, the question, and the
# initial response text.
prompt_template = (
    "\n"
    "### Instruction and Input:\n"
    "Based on the following context/document:\n"
    "{}\n"
    "Please answer the question: {}\n"
    "### Response:\n"
    "{}\n"
)
def process_audio_rag(file_path, pdf_path="/content/BN_Cotton.pdf"):
    """Answer a spoken question about a PDF, returning text and speech.

    Pipeline: Whisper transcribes the audio, the local RAG model answers the
    question with the PDF text as context, the Groq-hosted LLM produces the
    final answer from that draft, and gTTS renders the answer to an MP3.

    Args:
        file_path: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``. May be None/empty when no audio
            has been provided.
        pdf_path: Path to the reference PDF. Defaults to the document this
            app was written for, so existing callers are unaffected.

    Returns:
        A ``(response_text, audio_path)`` tuple. On failure the text holds an
        error message and the audio path is None.
    """
    import logging

    # Gradio may call the handler with no audio (e.g. when the input is
    # cleared); report that plainly instead of surfacing a Whisper error.
    if not file_path:
        return "Please record or upload an audio question.", None
    try:
        question = _transcribe(file_path)
        context = load_pdf(pdf_path)
        rag_answer = _rag_answer(context, question)
        response_message = _refine_with_groq(question, rag_answer)
        return response_message, _synthesize_speech(response_message)
    except Exception as e:  # UI boundary: show the error in the textbox.
        logging.getLogger(__name__).exception("process_audio_rag failed")
        return f"An error occurred: {e}", None


def _transcribe(file_path):
    """Return Whisper's transcription of the audio file at *file_path*."""
    audio = whisper.load_audio(file_path)
    return whisper_model.transcribe(audio)["text"]


def _rag_answer(context, question):
    """Return the RAG model's answer to *question*, grounded in *context*."""
    input_text = prompt_template.format(context, question, " ")
    inputs = rag_tokenizer(input_text, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    outputs = rag_model.generate(
        **inputs,
        max_new_tokens=500,
        no_repeat_ngram_size=5,
    )
    # generate() returns the prompt followed by the continuation. Decode only
    # the new tokens so the whole PDF context is not echoed into the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    return rag_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _refine_with_groq(question, rag_answer):
    """Return the Groq LLM's final answer, given the question and RAG draft."""
    # Send the question and the draft rather than the full RAG prompt. That
    # prompt embeds the entire PDF, which wastes tokens and can exceed the
    # 8192-token context of llama3-8b-8192.
    content = (
        f"Question: {question.strip()}\n\n"
        f"Draft answer from a document-grounded assistant:\n{rag_answer.strip()}\n\n"
        "Using the draft answer, reply to the question clearly and concisely."
    )
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": content}],
        model="llama3-8b-8192",  # Replace with the correct model if necessary
    )
    return chat_completion.choices[0].message.content.strip()


def _synthesize_speech(text):
    """Render *text* to speech with gTTS and return the path of the MP3."""
    import tempfile

    # Use a unique file per request. A shared fixed "response.mp3" could be
    # overwritten by a concurrent request before Gradio serves it. The file
    # is kept (delete=False) because Gradio reads it after this returns.
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio_file:
        gTTS(text).write_to_fp(audio_file)
    return audio_file.name
# Gradio UI: one audio input, handed to the handler as a file path, and two
# outputs, the answer text and its spoken audio.
_audio_input = gr.Audio(type="filepath")
_answer_outputs = [
    gr.Textbox(label="Response Text"),
    gr.Audio(label="Response Audio"),
]

# live=True makes Gradio run the handler when the input changes, rather than
# waiting for a Submit click.
iface = gr.Interface(
    fn=process_audio_rag,
    inputs=_audio_input,
    outputs=_answer_outputs,
    live=True,
    title="Agriculture Assistant",
)

# Start the Gradio web server (the title is set on the Interface above).
iface.launch()