# app.py -- hosting-page metadata (not part of the program):
# uploaded by amiguel, commit "Update app.py" (c9b60f0, verified), no virus, 1.55 kB.
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM
# Load the model and tokenizer
model_name = "amiguel/itemClassification_Alpaca_Mistral"

# Set the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The original passed an undefined `load_in_4bit` name (NameError at startup).
# bitsandbytes 4-bit loading generally needs a CUDA GPU, so it is enabled only
# when one is available; on CPU the model is loaded unquantized.
load_in_4bit = torch.cuda.is_available()


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the PEFT model and its tokenizer once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects avoids re-reading the checkpoint on each rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoPeftModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=load_in_4bit,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # transformers rejects `.to()` on 4-bit quantized models (they are placed
    # on the GPU at load time), so only move the model when it is unquantized.
    if not load_in_4bit:
        loaded_model.to(device)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()
# Create a Streamlit app
# NOTE(review): the title says "GPT2", but the repo id in `model_name` mentions
# Mistral -- confirm the intended title before changing this user-facing text.
st.title("GPT2 Text Generation App")
# Create input fields for the prompt
prompt = st.text_input("Enter the prompt:")
# min_value=1 / step=1 keep both settings positive integers; zero or negative
# values are not meaningful for a length limit or a beam count.
max_length = st.number_input(
    "Enter the maximum length of the generated text:",
    min_value=1,
    value=100,
    step=1,
)
num_beams = st.number_input(
    "Enter the number of beams:",
    min_value=1,
    value=4,
    step=1,
)
# Create an output field (placeholder that infer() fills in later)
output_field = st.empty()
# Define the inference function
def infer():
    """Generate text for the current prompt and show it in ``output_field``.

    Reads the module-level ``prompt``, ``max_length`` and ``num_beams`` widget
    values and the module-level ``model``, ``tokenizer`` and ``device``.
    Returns ``None``; the result is written to ``output_field``.
    """
    # Nothing useful can be generated from a blank prompt; tell the user instead.
    if not prompt or not prompt.strip():
        output_field.warning("Please enter a prompt before generating.")
        return

    token_budget = int(max_length)
    beam_count = int(num_beams)

    # Tokenize the prompt (truncated to at most `token_budget` tokens)
    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        max_length=token_budget,
        truncation=True,
    )
    # Move the inputs to the device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Run inference. `max_new_tokens` is used instead of `max_length`: the
    # latter counts the prompt tokens too, so a prompt truncated to the same
    # limit would leave no room for any generated text.
    outputs = model.generate(
        **inputs,
        max_new_tokens=token_budget,
        num_beams=beam_count,
    )
    # Convert the output to a string
    output_str = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Display the output
    output_field.text(output_str)
# Render the "Generate Text" button and keep its return value.
infer_button = st.button("Generate Text")
# The button's return value gates the inference call.
if infer_button:
    infer()