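# GEMRec-Gallery / pages/Summary.py
# Last commit 0532088 by Ricercar: "custom radio style" (file size 9.98 kB).
# (Header recovered from a Hugging Face Spaces file view; the "raw", "history",
#  "blame" and "No virus" labels were page controls, not part of the source.)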
import os
import datasets
import numpy as np
import pandas as pd
import pymysql.cursors
import streamlit as st
from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page
from streamlit_extras.metric_cards import style_metric_cards
from streamlit_extras.stylable_container import stylable_container
from pages.Gallery import load_hf_dataset
from pages.Ranking import connect_to_db
class DashboardApp:
    """Streamlit dashboard summarising the models a user preferred while ranking.

    Ranking rows of the current user are read through the module-level
    ``RANKING_CONN`` connection (assigned in this page's ``__main__`` section),
    turned into per-model scores, and shown as a "top picks" podium plus a
    detailed table joined with ``roster``. Rows are indexed by column name,
    so the connection presumably uses a dict cursor — confirm in connect_to_db.
    """

    # Image URL template; formatted with an image_id from promptBook.
    _IMAGE_URL = "https://modelcofferbucket.s3-accelerate.amazonaws.com/{}.png"
    # Medal icons for podium ranks 0, 1, 2; any lower rank reuses bronze.
    _MEDALS = ('🥇', '🥈', '🥉')
    # Tables results may be read from. Table names cannot be bound as query
    # parameters, so they are checked against this allowlist before interpolation.
    _RESULT_TABLES = ('battle_results', 'sort_results')
    # Card styling used by podium() via stylable_container.
    _CARD_CSS = """
        {
            border: 1.5px solid rgba(49, 51, 63, 0.2);
            border-left: 0.5rem solid gold;
            border-radius: 5px;
            padding: calc(1em + 5px);
            gap: 0.5em;
            box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
            overflow-x: scroll;
        }
        """

    def __init__(self, roster, promptBook, session_finished):
        """Store the data the dashboard renders.

        Args:
            roster: DataFrame of model versions; read columns are
                ``modelVersion_id``, ``model_id``, ``model_name``,
                ``modelVersion_name`` and ``modelVersion_url``.
            promptBook: DataFrame with ``modelVersion_id`` and ``image_id``
                columns, used to list the images generated by each version.
            session_finished: keys of finished ranking sessions (stored; not
                read by the methods of this class).
        """
        self.roster = roster
        self.promptBook = promptBook
        self.session_finished = session_finished

    # ------------------------------------------------------------------ #
    # Database helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _ranking_query(cls, db_table, user_id, columns='*', tag_pattern=None):
        """Build a parameterised SELECT over the current user's ranking rows.

        Args:
            db_table: one of ``_RESULT_TABLES``.
            user_id: ``(username, timestamp)`` pair identifying the user session.
                Both values are passed as ``str`` to match the text the old
                f-string query embedded.
            columns: SQL select list; internal constant, never user input.
            tag_pattern: optional ``LIKE`` pattern applied to the ``tag`` column.

        Returns:
            ``(sql, params)`` suitable for ``cursor.execute(sql, params)``.

        Raises:
            ValueError: if ``db_table`` is not an allowed results table.
        """
        if db_table not in cls._RESULT_TABLES:
            raise ValueError(f'Unknown ranking table: {db_table!r}')
        sql = f"SELECT {columns} FROM {db_table} WHERE username = %s AND timestamp = %s"
        params = [str(user_id[0]), str(user_id[1])]
        if tag_pattern is not None:
            sql += " AND tag LIKE %s"
            params.append(tag_pattern)
        return sql, tuple(params)

    @staticmethod
    def _fetch_all(sql, params):
        """Execute ``sql`` with bound ``params`` on RANKING_CONN and return all rows.

        The cursor is a context manager, so it is closed even if the query fails.
        """
        with RANKING_CONN.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    # ------------------------------------------------------------------ #
    # Rendering helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _medal(cls, rank):
        """Return the medal icon for a 0-based podium ``rank`` (bronze from rank 2 on)."""
        return cls._MEDALS[min(rank, len(cls._MEDALS) - 1)]

    def _model_info(self, modelVersion_id):
        """Return ``(model_id, model_name, modelVersion_name, modelVersion_url)``.

        Raises:
            IndexError: if ``modelVersion_id`` is not present in ``roster``.
        """
        matches = self.roster[self.roster['modelVersion_id'] == modelVersion_id]
        return matches[['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]

    def _show_images(self, modelVersion_id, col_num=4):
        """Render every promptBook image of ``modelVersion_id`` in a ``col_num``-column grid."""
        images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
        image_cols = st.columns(col_num)
        for idx, image_id in enumerate(images):
            with image_cols[idx % col_num]:
                st.image(self._IMAGE_URL.format(image_id), use_column_width=True)

    # ------------------------------------------------------------------ #
    # Page sections
    # ------------------------------------------------------------------ #
    def sidebar(self, tags, mode):
        """Render the sidebar and return the tag chosen in its selectbox.

        Args:
            tags: options for the tag selectbox (``'all'`` first).
            mode: ranking mode; currently unused.
        """
        with st.sidebar:
            tag = st.selectbox('Select a tag', tags, key='tag')
            st.write('---')
            st.write('## Want a more comprehensive summary?')
            st.write('Jump back to gallery and select more images to rank!')
            if st.button('🖼️ Go to Gallery', key='summary_sidebar_gallery'):
                switch_page('gallery')
            if st.button('🎖️ Go to Ranking', key='summary_sidebar_ranking'):
                switch_page('ranking')
        return tag

    def leaderboard(self, tag, db_table):
        """Score the user's rankings for ``tag`` and render the podium and detail tabs.

        Args:
            tag: a tag from the sidebar; ``'all'`` matches every tag.
            db_table: results table matching the user's ranking mode.
        """
        tag_pattern = '%' if tag == 'all' else tag
        # Values are bound as parameters rather than formatted into the SQL,
        # so user-controlled text cannot alter the query.
        sql, params = self._ranking_query(db_table, st.session_state.user_id, tag_pattern=tag_pattern)
        results = self._fetch_all(sql, params)

        standings = self.score_calculator(results, db_table)
        # (modelVersion_id, score) pairs, best first.
        modelVersion_standings = sorted(standings.items(), key=lambda item: item[1], reverse=True)

        tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
        with tab1:
            # podium() is an alternative card layout for the same data.
            self.podium_expander(modelVersion_standings)
        with tab2:
            st.write('## Detailed information of all selected models')
            detailed_info = pd.merge(
                pd.DataFrame(modelVersion_standings, columns=['modelVersion_id', 'ranking_score']),
                self.roster,
                on='modelVersion_id',
            )
            st.data_editor(detailed_info, hide_index=True, disabled=True)

    def podium(self, modelVersion_standings, n=3):
        """Render up to ``n`` top models as bordered cards, each with a 'Show Image' button.

        Clicking a card's button renders that model's images in a placeholder
        below the cards. Fewer cards are drawn when fewer than ``n`` models
        were ranked (previously this raised IndexError).
        """
        st.write('## Top picks')
        metric_cols = st.columns(n)
        image_display = st.empty()
        for rank, (modelVersion_id, score) in enumerate(modelVersion_standings[:n]):
            model_id, model_name, modelVersion_name, _url = self._model_info(modelVersion_id)
            icon = self._medal(rank)
            with metric_cols[rank]:
                with stylable_container(key="container_with_border", css_styles=self._CARD_CSS):
                    st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
                    st.write(f'Ranking Score: {score}')
                    show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
            if show_image:
                # The placeholder fixes where the images appear, regardless of which card was clicked.
                with image_display.container():
                    st.write('---')
                    st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
                    self._show_images(modelVersion_id)

    def podium_expander(self, modelVersion_standings, n=3):
        """Render up to ``n`` top models as expanders that reveal their generated images.

        Fewer expanders are drawn when fewer than ``n`` models were ranked
        (previously this raised IndexError).
        """
        st.write('## Top picks')
        for rank, (modelVersion_id, score) in enumerate(modelVersion_standings[:n]):
            model_id, model_name, modelVersion_name, _url = self._model_info(modelVersion_id)
            icon = self._medal(rank)
            label = (
                f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**'
                f', Ranking Score: {score}'
            )
            with st.expander(label):
                st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
                self._show_images(modelVersion_id)

    def score_calculator(self, results, db_table):
        """Turn raw ranking rows into a ``{modelVersion_id: score}`` dict.

        battle_results: rows are replayed in ``battletime`` order. The winner
        gains one point plus the loser's score at that moment. A loser seen for
        the first time starts at 0.
        sort_results: each row awards 5/3/1/0 points to positions 1-4.
        Any other table yields an empty dict.
        """
        modelVersion_standings = {}
        if db_table == 'battle_results':
            for record in sorted(results, key=lambda row: row['battletime']):
                winner, loser = record['winner'], record['loser']
                modelVersion_standings[winner] = modelVersion_standings.get(winner, 0) + 1
                # Register losers that have never won so they still appear with 0.
                modelVersion_standings.setdefault(loser, 0)
                # Beating a strong model is worth more: inherit the loser's score.
                modelVersion_standings[winner] += modelVersion_standings[loser]
        elif db_table == 'sort_results':
            pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
            for record in results:
                for position, points in pts_map.items():
                    model = record[position]
                    modelVersion_standings[model] = modelVersion_standings.get(model, 0) + points
        return modelVersion_standings

    def app(self):
        """Render the page: title, sidebar tag filter and leaderboard, then a comment box."""
        st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
        mode = st.session_state.assigned_rank_mode
        db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'

        # Collect the tags this user has ranked in the current mode.
        sql, params = self._ranking_query(db_table, st.session_state.user_id, columns='DISTINCT tag')
        tags = ['all'] + [row['tag'] for row in self._fetch_all(sql, params)]

        if tags == ['all']:
            st.info(f'No rankings are finished with {mode} mode yet.')
        else:
            tag = self.sidebar(tags, mode)
            self.leaderboard(tag, db_table)
        st.chat_input('Please leave your comments here.', key='comment')
if __name__ == "__main__":
    # Page entry point: the dashboard is shown only to a logged-in user who has
    # finished at least one ranking session; otherwise offer navigation buttons.
    st.set_page_config(layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        if st.button('Go to Home Page'):
            switch_page("home")
    elif 'progress' not in st.session_state:
        st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
        if st.button('🖼️ Go to Gallery'):
            switch_page('gallery')
    else:
        # Keys of st.session_state.progress whose value is 'finished'.
        session_finished = [key for key, value in st.session_state.progress.items() if value == 'finished']

        if not session_finished:
            st.info('A dashboard showing your preferred models will appear after you finish any ranking session.')
            if st.button('🎖️ Go to Ranking'):
                switch_page('ranking')
            if st.button('🖼️ Go to Gallery'):
                switch_page('gallery')
        else:
            # images_ds is not used on this page.
            roster, promptBook, images_ds = load_hf_dataset()
            # Module-level global; DashboardApp methods read it when querying rankings.
            RANKING_CONN = connect_to_db()
            app = DashboardApp(roster, promptBook, session_finished)
            app.app()