Ricercar commited on
Commit
06ca453
โ€ข
1 Parent(s): bfca1a2

update two-stage interface

Browse files
Files changed (3) hide show
  1. Home.py +2 -0
  2. pages/Gallery.py +13 -6
  3. pages/Ranking.py +19 -12
Home.py CHANGED
@@ -30,6 +30,7 @@ def save_user_id(user_id):
30
  if not user_id:
31
  user_id = 'anonymous' + str(random.randint(0, 100000))
32
  st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
 
33
 
34
 
35
  def logout():
@@ -38,6 +39,7 @@ def logout():
38
  st.session_state.pop('score_weights', None)
39
  st.session_state.pop('gallery_state', None)
40
  st.session_state.pop('progress', None)
 
41
 
42
 
43
  def info():
 
30
  if not user_id:
31
  user_id = 'anonymous' + str(random.randint(0, 100000))
32
  st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
33
+ st.session_state.assigned_rank_mode = random.choice(['sort', 'battle'])
34
 
35
 
36
  def logout():
 
39
  st.session_state.pop('score_weights', None)
40
  st.session_state.pop('gallery_state', None)
41
  st.session_state.pop('progress', None)
42
+ st.session_state.pop('gallery_focus', None)
43
 
44
 
45
  def info():
pages/Gallery.py CHANGED
@@ -32,6 +32,9 @@ class GalleryApp:
32
  if 'selected_dict' not in st.session_state:
33
  st.session_state['selected_dict'] = {}
34
 
 
 
 
35
  def gallery_standard(self, items, col_num, info):
36
  rows = len(items) // col_num + 1
37
  containers = [st.container() for _ in range(rows)]
@@ -310,7 +313,7 @@ class GalleryApp:
310
  # st.markdown(':orange[Please select a prompt above๐Ÿ‘†]')
311
  st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
312
 
313
- with subset_selector[1]:
314
  st.write(':orange[๐Ÿ‘ˆ **Please select a prompt**]')
315
 
316
  else:
@@ -322,6 +325,10 @@ class GalleryApp:
322
  if prompt_id not in st.session_state.gallery_state:
323
  st.session_state.gallery_state[prompt_id] = 'graph'
324
 
 
 
 
 
325
  # add safety check for some prompts
326
  safety_check = True
327
  unsafe_prompts = {}
@@ -348,7 +355,7 @@ class GalleryApp:
348
  # # st.warning('No selected images found')
349
  # else:
350
  self.graph_mode(prompt_id, items)
351
- with subset_selector[1]:
352
  # if st.session_state.gallery_state[prompt_id] == 'graph':
353
  # subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
354
  has_selection = False
@@ -359,7 +366,7 @@ class GalleryApp:
359
  pass
360
 
361
  if has_selection:
362
- checkout = st.button('Check out selections', use_container_width=True, type='primary')
363
  if checkout:
364
  print('checkout')
365
 
@@ -367,17 +374,17 @@ class GalleryApp:
367
  print(st.session_state.gallery_state[prompt_id])
368
  st.experimental_rerun()
369
  else:
370
- st.write('Select images you like below ๐Ÿ‘‡')
371
 
372
  elif st.session_state.gallery_state[prompt_id] == 'gallery':
373
  items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
374
  drop=True)
375
  self.gallery_mode(prompt_id, items)
376
 
377
- with subset_selector[1]:
378
  state_operations = st.columns([1, 1])
379
  with state_operations[0]:
380
- back = st.button('Back', use_container_width=True)
381
  if back:
382
  st.session_state.gallery_state[prompt_id] = 'graph'
383
  st.experimental_rerun()
 
32
  if 'selected_dict' not in st.session_state:
33
  st.session_state['selected_dict'] = {}
34
 
35
+ if 'gallery_focus' not in st.session_state:
36
+ st.session_state.gallery_focus = {'tag': None, 'prompt': None}
37
+
38
  def gallery_standard(self, items, col_num, info):
39
  rows = len(items) // col_num + 1
40
  containers = [st.container() for _ in range(rows)]
 
313
  # st.markdown(':orange[Please select a prompt above๐Ÿ‘†]')
314
  st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
315
 
316
+ with subset_selector[-1]:
317
  st.write(':orange[๐Ÿ‘ˆ **Please select a prompt**]')
318
 
319
  else:
 
325
  if prompt_id not in st.session_state.gallery_state:
326
  st.session_state.gallery_state[prompt_id] = 'graph'
327
 
328
+ # add focus to session state
329
+ st.session_state.gallery_focus['tag'] = tag
330
+ st.session_state.gallery_focus['prompt'] = selected_prompt
331
+
332
  # add safety check for some prompts
333
  safety_check = True
334
  unsafe_prompts = {}
 
355
  # # st.warning('No selected images found')
356
  # else:
357
  self.graph_mode(prompt_id, items)
358
+ with subset_selector[-1]:
359
  # if st.session_state.gallery_state[prompt_id] == 'graph':
360
  # subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
361
  has_selection = False
 
366
  pass
367
 
368
  if has_selection:
369
+ checkout = st.button('๐Ÿ›’ Check out selections', use_container_width=True, type='primary')
370
  if checkout:
371
  print('checkout')
372
 
 
374
  print(st.session_state.gallery_state[prompt_id])
375
  st.experimental_rerun()
376
  else:
377
+ st.write(':orange[๐Ÿ‘‡ **Select images you like below**]')
378
 
379
  elif st.session_state.gallery_state[prompt_id] == 'gallery':
380
  items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
381
  drop=True)
382
  self.gallery_mode(prompt_id, items)
383
 
384
+ with subset_selector[-1]:
385
  state_operations = st.columns([1, 1])
386
  with state_operations[0]:
387
+ back = st.button('Back to ๐Ÿ–ผ๏ธ', use_container_width=True)
388
  if back:
389
  st.session_state.gallery_state[prompt_id] = 'graph'
390
  st.experimental_rerun()
pages/Ranking.py CHANGED
@@ -27,13 +27,20 @@ class RankingApp:
27
  def sidebar(self):
28
  with st.sidebar:
29
  prompt_tags = self.promptBook['tag'].unique()
30
- prompt_tags = np.sort(prompt_tags)
31
 
32
- tag = st.selectbox('Select a prompt tag', prompt_tags)
 
 
 
 
33
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
34
- prompts = np.sort(items['prompt'].unique())[::-1]
 
 
 
35
 
36
- selected_prompt = st.selectbox('Select a prompt', prompts)
37
 
38
  mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
39
 
@@ -204,24 +211,24 @@ class RankingApp:
204
  with left:
205
  image_id = items['image_id'][st.session_state.pointer[prompt_id]['left']]
206
  img_url = self.images_endpoint + str(image_id) + '.png'
207
- st.image(img_url, use_column_width=True)
208
 
209
- # write the total score of this image
210
- total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
211
- st.write(f'Total Score: {total_score}')
212
 
213
  btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
 
214
 
215
  with right:
216
  image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
217
  img_url = self.images_endpoint + str(image_id) + '.png'
218
- st.image(img_url, use_column_width=True)
219
 
220
- # write the total score of this image
221
- total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
222
- st.write(f'Total Score: {total_score}')
223
 
224
  btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
 
225
 
226
  def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
227
  loser = 'left' if winner == 'right' else 'right'
 
27
  def sidebar(self):
28
  with st.sidebar:
29
  prompt_tags = self.promptBook['tag'].unique()
30
+ prompt_tags = np.sort(prompt_tags).tolist()
31
 
32
+ print(st.session_state.gallery_focus)
33
+ tag_idx = prompt_tags.index(st.session_state.gallery_focus['tag']) if st.session_state.gallery_focus['tag'] in prompt_tags else 0
34
+ print(tag_idx)
35
+
36
+ tag = st.selectbox('Select a prompt tag', prompt_tags, index=tag_idx)
37
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
38
+ prompts = np.sort(items['prompt'].unique())[::-1].tolist()
39
+
40
+ prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus['prompt'] in prompts else 0
41
+ print(prompt_idx)
42
 
43
+ selected_prompt = st.selectbox('Select a prompt', prompts, index=prompt_idx)
44
 
45
  mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
46
 
 
211
  with left:
212
  image_id = items['image_id'][st.session_state.pointer[prompt_id]['left']]
213
  img_url = self.images_endpoint + str(image_id) + '.png'
 
214
 
215
+ # # write the total score of this image
216
+ # total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
217
+ # st.write(f'Total Score: {total_score}')
218
 
219
  btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
220
+ st.image(img_url, use_column_width=True)
221
 
222
  with right:
223
  image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
224
  img_url = self.images_endpoint + str(image_id) + '.png'
 
225
 
226
+ # # write the total score of this image
227
+ # total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
228
+ # st.write(f'Total Score: {total_score}')
229
 
230
  btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
231
+ st.image(img_url, use_column_width=True)
232
 
233
  def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
234
  loser = 'left' if winner == 'right' else 'right'