Ricercar commited on
Commit
d504e45
β€’
1 Parent(s): 2386b40

update battle mode

Browse files
pages/Ranking.py CHANGED
@@ -6,6 +6,7 @@ import pandas as pd
6
  import pymysql.cursors
7
  import streamlit as st
8
 
 
9
  from streamlit_elements import elements, mui, html, dashboard, nivo
10
  from streamlit_extras.switch_page_button import switch_page
11
 
@@ -224,6 +225,7 @@ class RankingApp:
224
 
225
  def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
226
  loser = 'left' if winner == 'right' else 'right'
 
227
 
228
  curser = RANKING_CONN.cursor()
229
 
@@ -236,8 +238,8 @@ class RankingApp:
236
  curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
237
 
238
  # insert the battle result into the database
239
- query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser) VALUES (%s, %s, %s, %s, %s, %s)"
240
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
241
 
242
  curser.close()
243
  RANKING_CONN.commit()
@@ -282,22 +284,27 @@ class RankingApp:
282
  elif st.session_state.progress[prompt_id] == 'finished':
283
  st.write('## You have ranked all models for this tag!')
284
  st.write('Thank you for your participation! Feel free to do the following things:')
285
- st.write('* Rank for other tags and prompts.')
286
- st.write('* Back to the gallery page to see more images.')
287
- st.write('* Rank again for this tag and prompt.')
288
- st.write('*More functions are coming soon... Please stay tuned*')
289
-
290
- gallery_btn = st.button('πŸ–ΌοΈ Back to Gallery')
291
- if gallery_btn:
292
- switch_page('gallery')
293
-
294
- restart_btn = st.button('πŸŽ–οΈ Rank Again')
295
  if restart_btn:
296
  st.session_state.progress[prompt_id] = 'ranking'
297
  st.session_state.counter[prompt_id] = 0
298
  st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
299
  st.experimental_rerun()
300
 
 
 
 
 
 
 
 
 
301
 
302
  def connect_to_db():
303
  conn = pymysql.connect(
 
6
  import pymysql.cursors
7
  import streamlit as st
8
 
9
+ from datetime import datetime
10
  from streamlit_elements import elements, mui, html, dashboard, nivo
11
  from streamlit_extras.switch_page_button import switch_page
12
 
 
225
 
226
  def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
227
  loser = 'left' if winner == 'right' else 'right'
228
+ battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
229
 
230
  curser = RANKING_CONN.cursor()
231
 
 
238
  curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
239
 
240
  # insert the battle result into the database
241
+ query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime) VALUES (%s, %s, %s, %s, %s, %s, %s)"
242
+ curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime))
243
 
244
  curser.close()
245
  RANKING_CONN.commit()
 
284
  elif st.session_state.progress[prompt_id] == 'finished':
285
  st.write('## You have ranked all models for this tag!')
286
  st.write('Thank you for your participation! Feel free to do the following things:')
287
+ # st.write('* Rank for other tags and prompts.')
288
+ # st.write('* Back to the gallery page to see more images.')
289
+ # st.write('* Rank again for this tag and prompt.')
290
+ # st.write('* Check the summary to see what model you like most.')
291
+ # st.write('*More functions are coming soon... Please stay tuned*')
292
+ st.button('πŸ‘ˆ Rank for other tags and prompts')
293
+ restart_btn = st.button('πŸŽ–οΈ Rank this prompt again')
 
 
 
294
  if restart_btn:
295
  st.session_state.progress[prompt_id] = 'ranking'
296
  st.session_state.counter[prompt_id] = 0
297
  st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
298
  st.experimental_rerun()
299
 
300
+ gallery_btn = st.button('πŸ–ΌοΈ Back to Gallery')
301
+ if gallery_btn:
302
+ switch_page('gallery')
303
+
304
+ summary_btn = st.button('πŸ“Š See Summary')
305
+ if summary_btn:
306
+ switch_page('summary')
307
+
308
 
309
  def connect_to_db():
310
  conn = pymysql.connect(
pages/{Results.py β†’ Summary.py} RENAMED
@@ -49,20 +49,20 @@ class DashboardApp:
49
  n = 3
50
  metric_cols = st.columns(n)
51
  image_display = st.empty()
52
-
53
  for i in range(n):
54
  with metric_cols[i]:
55
  modelVersion_id = modelVersion_standings[i][0]
56
  winning_times = modelVersion_standings[i][1]
57
 
58
- model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
59
 
60
  metric_card = stylable_container(
61
  key="container_with_border",
62
  css_styles="""
63
  {
64
- border: 1.5px solid rgba(49, 51, 63, 0.8);
65
- border-left: 0.5rem solid silver;
66
  border-radius: 5px;
67
  padding: calc(1em + 5px);
68
  gap: 0.5em;
@@ -74,8 +74,8 @@ class DashboardApp:
74
 
75
  with metric_card:
76
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
77
- st.write(modelVersion_id)
78
- st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{modelVersion_id})')
79
  st.write(f'Ranking Score: {winning_times}')
80
 
81
  show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
@@ -105,20 +105,21 @@ class DashboardApp:
105
 
106
 
107
  def score_calculator(self, results, db_table):
 
 
 
108
  modelVersion_standings = {}
109
  if db_table == 'battle_results':
110
  for record in results:
111
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
112
- # add the winning time of the loser
113
- curser = RANKING_CONN.cursor()
114
- curser.execute(f"SELECT COUNT(*) FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND winner = '{record['loser']}'")
115
- modelVersion_standings[record['winner']] += curser.fetchone()['COUNT(*)']
116
- curser.close()
117
 
118
  # add the loser who never wins
119
  if record['loser'] not in modelVersion_standings:
120
  modelVersion_standings[record['loser']] = 0
121
 
 
 
 
122
  elif db_table == 'sort_results':
123
  pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
124
  for record in results:
@@ -128,11 +129,10 @@ class DashboardApp:
128
  return modelVersion_standings
129
 
130
 
131
-
132
  def app(self):
133
  st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
134
 
135
- mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True)
136
  # get tags from database of the current user
137
  db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
138
 
 
49
  n = 3
50
  metric_cols = st.columns(n)
51
  image_display = st.empty()
52
+
53
  for i in range(n):
54
  with metric_cols[i]:
55
  modelVersion_id = modelVersion_standings[i][0]
56
  winning_times = modelVersion_standings[i][1]
57
 
58
+ model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
59
 
60
  metric_card = stylable_container(
61
  key="container_with_border",
62
  css_styles="""
63
  {
64
+ border: 1.5px solid rgba(49, 51, 63, 0.2);
65
+ border-left: 0.5rem solid gold;
66
  border-radius: 5px;
67
  padding: calc(1em + 5px);
68
  gap: 0.5em;
 
74
 
75
  with metric_card:
76
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
77
+ # st.write(model_id)
78
+ st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
79
  st.write(f'Ranking Score: {winning_times}')
80
 
81
  show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
 
105
 
106
 
107
  def score_calculator(self, results, db_table):
108
+ # sort results by battle time
109
+ results = sorted(results, key=lambda x: x['battletime'])
110
+
111
  modelVersion_standings = {}
112
  if db_table == 'battle_results':
113
  for record in results:
114
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
 
 
 
 
 
115
 
116
  # add the loser who never wins
117
  if record['loser'] not in modelVersion_standings:
118
  modelVersion_standings[record['loser']] = 0
119
 
120
+ # add the winning time of the loser to the winner
121
+ modelVersion_standings[record['winner']] += modelVersion_standings[record['loser']]
122
+
123
  elif db_table == 'sort_results':
124
  pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
125
  for record in results:
 
129
  return modelVersion_standings
130
 
131
 
 
132
  def app(self):
133
  st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
134
 
135
+ mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True, index=1)
136
  # get tags from database of the current user
137
  db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
138