"""Streamlit app that scores text for toxicity with a selectable model.

The page has two parts: a free-text box whose input is scored on "Submit",
and a table of scores for three comments sampled at random from the Jigsaw
test CSV.
"""

import random
from functools import lru_cache

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Output labels, in the order used for the score columns.
classifiers = ['toxic', 'severe_toxic', 'obscene',
               'threat', 'insult', 'identity_hate']

# Locally saved checkpoint directory for each selectable base model.
_MODEL_DIRS = {
    "bert-base-cased": "./bert/_bert_model",
    "distilbert-base-cased": "./distilbert/_distilbert_model",
}
# Any other selection falls back to the RoBERTa checkpoint.
_DEFAULT_MODEL_DIR = "./roberta/_roberta_model"

# Prefer Streamlit's resource cache when this Streamlit version provides it;
# otherwise fall back to a plain in-process memo, which still avoids
# reloading the model for each of the several get_score() calls below.
_cache_resource = getattr(st, "cache_resource", None) or lru_cache(maxsize=None)


def reset_scores():
    """Replace the global ``scores_df`` with an empty frame.

    The frame has a ``Comment`` column followed by one column per entry in
    ``classifiers``.
    """
    global scores_df
    scores_df = pd.DataFrame(columns=['Comment'] + classifiers)


@_cache_resource
def _load_model_and_tokenizer(model_base):
    """Load the classifier checkpoint and tokenizer for ``model_base``."""
    model_dir = _MODEL_DIRS.get(model_base, _DEFAULT_MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_base)
    return model, tokenizer


def get_score(model_base, text):
    """Return sigmoid-activated model outputs for ``text``.

    Args:
        model_base: Name of the base model; selects both the tokenizer and
            the local checkpoint directory. Unrecognised names use the
            RoBERTa checkpoint.
        text: The text to classify; it is truncated to 512 tokens.

    Returns:
        A tensor holding ``torch.sigmoid`` of the model logits. Callers in
        this file read row 0 as the per-label scores.
    """
    model, tokenizer = _load_model_and_tokenizer(model_base)
    inputs = tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
        return torch.sigmoid(outputs.logits)


st.title("Toxic Comment Classifier")

model_base = st.selectbox(
    "Select a pretrained model",
    ["roberta-base", "bert-base-cased", "distilbert-base-cased"],
)
text_input = st.text_input("Enter text for toxicity classification", "")
submit_btn = st.button("Submit")

if submit_btn and text_input:
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    # Round to 2 decimal places, then render each value as a whole percent.
    # Series.map is used per column instead of DataFrame.applymap, which is
    # deprecated in recent pandas releases.
    df = df.round(2).apply(lambda col: col.map('{:.0%}'.format))
    st.table(df)

test_df = pd.read_csv(
    "./jigsaw-toxic-comment-classification-challenge/test.csv")
sample_df = test_df.sample(n=3)

# Score the sampled comments and build the results frame in one step rather
# than growing it row by row.
reset_scores()
score_rows = [
    [comment] + get_score(model_base, comment)[0].tolist()
    for comment in sample_df['comment_text']
]
scores_df = pd.DataFrame(score_rows, columns=scores_df.columns).round(2)

st.subheader("Toxicity Scores for Random Comments")
if st.button("Refresh"):
    # The click itself re-runs this script, so a new sample has already been
    # drawn and scored above. Do not clear scores_df here, or the table
    # below would be rendered empty.
    st.success("New tweets have been loaded!")
st.table(scores_df)