# Mile-stone-3 / app.py: Streamlit toxic comment classifier
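# Run locally with: streamlit run app.py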
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import pandas as pd

# The six label columns from the Jigsaw Toxic Comment Classification Challenge.
classifiers = ['toxic', 'severe_toxic', 'obscene',
               'threat', 'insult', 'identity_hate']
def reset_scores():
    """Reset the global DataFrame that accumulates per-comment scores."""
    global scores_df
    scores_df = pd.DataFrame(columns=['Comment'] + classifiers)
def get_score(model_base, text):
    # Map the selected base checkpoint to its fine-tuned model directory.
    if model_base == "bert-base-cased":
        model_dir = "./bert/_bert_model"
    elif model_base == "distilbert-base-cased":
        model_dir = "./distilbert/_distilbert_model"
    else:
        model_dir = "./roberta/_roberta_model"
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # The tokenizer comes from the base checkpoint, presumably because the
    # fine-tuned directories store only model weights.
    tokenizer = AutoTokenizer.from_pretrained(model_base)
    inputs = tokenizer(text, max_length=512, truncation=True,
                       padding=True, return_tensors='pt')
    model.eval()  # disable dropout for deterministic inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Multi-label task: apply a sigmoid per label rather than a softmax
    # across labels.
    predictions = torch.sigmoid(outputs.logits)
    return predictions
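# Example (hypothetical usage, not part of the original app):
# preds = get_score("roberta-base", "you are a wonderful person")
# preds has shape (1, 6); preds[0][i] is the probability for classifiers[i].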
st.title("Toxic Comment Classifier")

model_base = st.selectbox("Select a pretrained model",
                          ["roberta-base", "bert-base-cased", "distilbert-base-cased"])
text_input = st.text_input("Enter text for toxicity classification", "")
submit_btn = st.button("Submit")
if submit_btn and text_input:
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    df = df.round(2)
    # Display each score as a whole-number percentage.
    df = df.applymap(lambda x: '{:.0%}'.format(x))
    st.table(df)
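# Note: the sigmoid yields an independent probability per label, so the
# percentages in a row are not expected to sum to 100%.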
# Score a small random sample from the Jigsaw test set for comparison.
test_df = pd.read_csv(
    "./jigsaw-toxic-comment-classification-challenge/test.csv")
sample_df = test_df.sample(n=3)
reset_scores()
for index, row in sample_df.iterrows():
    result = get_score(model_base, row['comment_text'])
    scores = result[0].tolist()
    scores_df.loc[len(scores_df)] = [row['comment_text']] + scores
scores_df = scores_df.round(2)
st.subheader("Toxicity Scores for Random Comments")
# Streamlit reruns the whole script on every interaction, so clicking
# "Refresh" re-samples and re-scores three new comments automatically;
# calling reset_scores() here would blank the table that was just filled.
if st.button("Refresh"):
    st.success("New comments have been loaded!")
st.table(scores_df)
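# Possible optimization (not in the original): get_score reloads the model and
# tokenizer from disk on every call; wrapping those loads in a function
# decorated with st.cache_resource (available in recent Streamlit releases)
# would reuse them across reruns.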