import os
import datasets
import numpy as np
import pandas as pd
import pymysql.cursors
import streamlit as st
from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page
from streamlit_extras.metric_cards import style_metric_cards
from streamlit_extras.stylable_container import stylable_container

from pages.Gallery import load_hf_dataset
from pages.Ranking import connect_to_db

# Tables this page reads from. Table names cannot be passed as bound query
# parameters, so any name interpolated into SQL is checked against this list.
RANKING_TABLES = ('battle_results', 'sort_results')

# URL pattern of the generated images; formatted with an image_id.
IMAGE_URL_TEMPLATE = "https://modelcofferbucket.s3-accelerate.amazonaws.com/{}.png"

# Points awarded to the model placed in each position in 'Sort' mode.
SORT_POINTS = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}


def _check_table(db_table):
    """Raise ValueError unless *db_table* is one of RANKING_TABLES."""
    if db_table not in RANKING_TABLES:
        raise ValueError(f'Unknown ranking table: {db_table!r}')


def _rank_icon(rank):
    """Return the medal emoji for the 0-based *rank* (bronze for rank >= 2)."""
    if rank == 0:
        return '🥇'
    if rank == 1:
        return '🥈'
    return '🥉'


class DashboardApp:
    """Streamlit page summarising which models the current user preferred.

    The methods use the module-level ``RANKING_CONN`` connection, which is
    created in the ``__main__`` section at the bottom of this file before
    ``app()`` is called. Rows are read as dicts (``row['tag']``), which
    presumably means ``connect_to_db`` configures a DictCursor — verify there.
    """

    def __init__(self, roster, promptBook, session_finished):
        # roster: DataFrame with modelVersion_id, model_id, model_name,
        #   modelVersion_name and modelVersion_url columns (as indexed below).
        # promptBook: DataFrame with modelVersion_id and image_id columns.
        # session_finished: keys of st.session_state.progress marked 'finished'.
        self.roster = roster
        self.promptBook = promptBook
        self.session_finished = session_finished

    # ------------------------------------------------------------------ helpers

    def _query(self, sql, params):
        """Execute *sql* with bound *params* and return every fetched row.

        The cursor is closed even when execution raises.
        """
        with RANKING_CONN.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    def _user_key(self):
        """Return the (username, timestamp) pair identifying the current user."""
        user_id = st.session_state.user_id
        return user_id[0], user_id[1]

    def _model_info(self, modelVersion_id):
        """Return (model_id, model_name, modelVersion_name, url) from the roster."""
        row = self.roster[self.roster['modelVersion_id'] == modelVersion_id]
        return row[['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]

    def _render_image_grid(self, modelVersion_id, col_num=4):
        """Show every promptBook image of *modelVersion_id* in a *col_num*-column grid."""
        images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
        image_cols = st.columns(col_num)
        for idx, image_id in enumerate(images):
            with image_cols[idx % col_num]:
                st.image(IMAGE_URL_TEMPLATE.format(image_id), use_column_width=True)

    # --------------------------------------------------------------- rendering

    def sidebar(self, tags, mode):
        """Render the tag selector in the sidebar and return the chosen tag."""
        with st.sidebar:
            tag = st.selectbox('Select a tag', tags, key='tag')
        return tag

    def leaderboard(self, tag, db_table):
        """Render the 'Top Picks' / 'Detailed Info' tabs for the current user.

        Args:
            tag: a tag to filter by, or 'all' for every (non-NULL) tag.
            db_table: 'battle_results' or 'sort_results'.
        """
        _check_table(db_table)
        username, timestamp = self._user_key()

        # User-controlled values are bound as parameters, never formatted in.
        sql = f"SELECT * FROM {db_table} WHERE username = %s AND timestamp = %s"
        params = [username, timestamp]
        if tag == 'all':
            # Equivalent to the former LIKE '%': any non-NULL tag.
            sql += " AND tag IS NOT NULL"
        else:
            # Exact match, so '%' or '_' inside a tag are not wildcards.
            sql += " AND tag = %s"
            params.append(tag)
        results = self._query(sql, params)

        standings = self.score_calculator(results, db_table)
        # (modelVersion_id, score) pairs, highest score first.
        modelVersion_standings = sorted(standings.items(), key=lambda item: item[1], reverse=True)

        tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
        with tab1:
            self.podium_expander(modelVersion_standings)

        with tab2:
            st.write('## Detailed information of all selected models')
            detailed_info = pd.merge(
                pd.DataFrame(modelVersion_standings, columns=['modelVersion_id', 'ranking_score']),
                self.roster,
                on='modelVersion_id',
            )
            st.data_editor(detailed_info, hide_index=True, disabled=True)

    def podium(self, modelVersion_standings, n=3):
        """Render the top *n* models as side-by-side cards with a 'Show Image' button.

        Shows fewer than *n* cards when fewer models have been ranked.
        """
        st.write('## Top picks')
        n = min(n, len(modelVersion_standings))
        if n == 0:
            # st.columns() rejects a count of 0.
            return

        metric_cols = st.columns(n)
        # Shared slot under the cards; the most recently clicked model's images go here.
        image_display = st.empty()

        for rank, (modelVersion_id, score) in enumerate(modelVersion_standings[:n]):
            model_id, model_name, modelVersion_name, _url = self._model_info(modelVersion_id)
            icon = _rank_icon(rank)
            with metric_cols[rank]:
                metric_card = stylable_container(
                    # One key per card so the containers do not collide.
                    key=f"container_with_border_{rank}",
                    css_styles="""
                        {
                            border: 1.5px solid rgba(49, 51, 63, 0.2);
                            border-left: 0.5rem solid gold;
                            border-radius: 5px;
                            padding: calc(1em + 5px);
                            gap: 0.5em;
                            box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
                            overflow-x: scroll;
                        }
                        """,
                )
                with metric_card:
                    st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
                    st.write(f'Ranking Score: {score}')
                    show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
                    if show_image:
                        with image_display.container():
                            st.write('---')
                            st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
                            self._render_image_grid(modelVersion_id)

    def podium_expander(self, modelVersion_standings, n=3):
        """Render the top *n* models as expanders holding their generated images.

        Shows fewer than *n* entries when fewer models have been ranked.
        """
        st.write('## Top picks')
        for rank, (modelVersion_id, score) in enumerate(modelVersion_standings[:n]):
            model_id, model_name, modelVersion_name, _url = self._model_info(modelVersion_id)
            icon = _rank_icon(rank)
            with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {score}'):
                st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
                self._render_image_grid(modelVersion_id)

    def score_calculator(self, results, db_table):
        """Compute a score per modelVersion_id from the user's ranking rows.

        Rows are processed in ascending 'battletime' order.

        * 'battle_results': the winner gains 1 point plus the loser's current
          score. A loser is added with 0 points if it has not been seen yet.
        * 'sort_results': each row's position1..position4 models gain
          SORT_POINTS points.

        Returns:
            dict mapping modelVersion_id -> score (empty for any other table).
        """
        results = sorted(results, key=lambda record: record['battletime'])
        standings = {}

        if db_table == 'battle_results':
            for record in results:
                winner, loser = record['winner'], record['loser']
                standings[winner] = standings.get(winner, 0) + 1
                # A model that has only ever lost still gets an entry.
                if loser not in standings:
                    standings[loser] = 0
                # The winner also absorbs the loser's accumulated score.
                standings[winner] += standings[loser]

        elif db_table == 'sort_results':
            for record in results:
                for position, points in SORT_POINTS.items():
                    model = record[position]
                    standings[model] = standings.get(model, 0) + points

        return standings

    def app(self):
        """Render the page: mode picker, tag picker, then the leaderboard."""
        st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
        mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True, index=1)

        db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
        _check_table(db_table)
        username, timestamp = self._user_key()
        rows = self._query(
            f"SELECT DISTINCT tag FROM {db_table} WHERE username = %s AND timestamp = %s",
            (username, timestamp),
        )
        tags = ['all'] + [row['tag'] for row in rows]

        if tags == ['all']:
            st.info(f'No rankings are finished with {mode} mode yet.')
        else:
            tag = self.sidebar(tags, mode)
            self.leaderboard(tag, db_table)

        st.chat_input('Please leave your comments here.', key='comment')


if __name__ == "__main__":
    st.set_page_config(layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        home_btn = st.button('Go to Home Page')
        if home_btn:
            switch_page("home")

    elif 'progress' not in st.session_state:
        st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
        gallery_btn = st.button('🖼️ Go to Gallery')
        if gallery_btn:
            switch_page('gallery')

    else:
        session_finished = [key for key, value in st.session_state.progress.items() if value == 'finished']

        if not session_finished:
            st.info('A dashboard showing your preferred models will appear after you finish any ranking session.')
            ranking_btn = st.button('🎖️ Go to Ranking')
            if ranking_btn:
                switch_page('ranking')
            gallery_btn = st.button('🖼️ Go to Gallery')
            if gallery_btn:
                switch_page('gallery')

        else:
            roster, promptBook, images_ds = load_hf_dataset()
            # Module-level global read by DashboardApp's query helpers.
            RANKING_CONN = connect_to_db()
            app = DashboardApp(roster, promptBook, session_finished)
            app.app()