Ricercar commited on
Commit
bfca1a2
1 Parent(s): c9b39a0

add two-substage gallery page

Browse files
Files changed (3) hide show
  1. Home.py +2 -0
  2. pages/Gallery.py +104 -45
  3. pages/Summary.py +3 -3
Home.py CHANGED
@@ -36,6 +36,8 @@ def logout():
36
  st.session_state.pop('user_id', None)
37
  st.session_state.pop('selected_dict', None)
38
  st.session_state.pop('score_weights', None)
 
 
39
 
40
 
41
  def info():
 
36
  st.session_state.pop('user_id', None)
37
  st.session_state.pop('selected_dict', None)
38
  st.session_state.pop('score_weights', None)
39
+ st.session_state.pop('gallery_state', None)
40
+ st.session_state.pop('progress', None)
41
 
42
 
43
  def info():
pages/Gallery.py CHANGED
@@ -24,6 +24,14 @@ class GalleryApp:
24
  self.promptBook = promptBook
25
  self.images_ds = images_ds
26
 
 
 
 
 
 
 
 
 
27
  def gallery_standard(self, items, col_num, info):
28
  rows = len(items) // col_num + 1
29
  containers = [st.container() for _ in range(rows)]
@@ -276,6 +284,7 @@ class GalleryApp:
276
  # chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
277
  # tag = stx.tab_bar(chosen_data, key='tag', default='food')
278
 
 
279
  tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag')
280
 
281
  # tabs = st.tabs(prompt_tags)
@@ -284,23 +293,34 @@ class GalleryApp:
284
  # tag = prompt_tags[i]
285
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
286
 
287
- prompts = np.sort(items['prompt'].unique())[::1]
288
 
 
289
  subset_selector = st.columns([3, 1])
290
  with subset_selector[0]:
 
 
 
 
291
  # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
292
- selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
293
- with subset_selector[1]:
294
- subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
295
 
296
  if selected_prompt is None:
297
- st.markdown(':orange[Please select a prompt above👆]')
298
  st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
 
 
 
 
299
  else:
300
  items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
301
  prompt_id = items['prompt_id'].unique()[0]
302
  note = items['note'].unique()[0]
303
- print(prompt_id, note)
 
 
 
304
 
305
  # add safety check for some prompts
306
  safety_check = True
@@ -316,16 +336,61 @@ class GalleryApp:
316
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
317
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
318
 
319
- if safety_check:
320
-
321
- if subset == 'Selected Only' and 'selected_dict' in st.session_state:
322
- # try:
323
- items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
324
- self.gallery_mode(prompt_id, items)
325
- # except:
326
- # st.warning('No selected images found')
327
- else:
 
 
328
  self.graph_mode(prompt_id, items)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  try:
331
  self.sidebar(items, prompt_id, note)
@@ -390,8 +455,6 @@ class GalleryApp:
390
  infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
391
  st.table(infos_df)
392
 
393
- st.button('🎖️ Proceed selections to ranking', on_click=switch_page, args=("ranking",), use_container_width=True,)
394
-
395
  # for info in infos:
396
  # st.write(f"**{info}**:")
397
  # st.write(item[info])
@@ -399,8 +462,6 @@ class GalleryApp:
399
  else:
400
  st.info('Please click on an image to show')
401
 
402
-
403
-
404
  def gallery_mode(self, prompt_id, items):
405
  items, info, col_num = self.selection_panel(items)
406
 
@@ -423,33 +484,31 @@ class GalleryApp:
423
  # if prompt:
424
  # switch_page("ranking")
425
 
426
- with st.form(key=f'{prompt_id}'):
427
  # buttons = st.columns([1, 1, 1])
428
- buttons_space = st.columns([1, 1, 1])
429
- gallery_space = st.empty()
430
-
431
- with buttons_space[0]:
432
- continue_btn = st.form_submit_button('Proceed selections to ranking', use_container_width=True, type='primary')
433
- if continue_btn:
434
- # self.submit_actions('Continue', prompt_id)
435
- switch_page("ranking")
436
-
437
- with buttons_space[1]:
438
- deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
439
- if deselect_btn:
440
- self.submit_actions('Deselect', prompt_id)
441
-
442
- with buttons_space[2]:
443
- refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
444
 
445
- with gallery_space.container():
446
- with st.spinner('Loading images...'):
447
- self.gallery_standard(items, col_num, info)
448
 
449
  # st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
450
 
451
-
452
-
453
  def submit_actions(self, status, prompt_id):
454
  # remove counter from session state
455
  # st.session_state.pop('counter', None)
@@ -473,7 +532,7 @@ class GalleryApp:
473
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
474
  # switch_page("ranking")
475
  print(st.session_state.selected_dict, 'continue')
476
- st.experimental_rerun()
477
 
478
  def dynamic_weight(self, prompt_id, items, method='Grid Search'):
479
  selected = items[
@@ -656,9 +715,9 @@ if __name__ == "__main__":
656
  roster, promptBook, images_ds = load_hf_dataset()
657
  # print(promptBook.columns)
658
 
659
- # initialize selected_dict
660
- if 'selected_dict' not in st.session_state:
661
- st.session_state['selected_dict'] = {}
662
 
663
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
664
  app.app()
 
24
  self.promptBook = promptBook
25
  self.images_ds = images_ds
26
 
27
+ # init gallery state
28
+ if 'gallery_state' not in st.session_state:
29
+ st.session_state.gallery_state = {}
30
+
31
+ # initialize selected_dict
32
+ if 'selected_dict' not in st.session_state:
33
+ st.session_state['selected_dict'] = {}
34
+
35
  def gallery_standard(self, items, col_num, info):
36
  rows = len(items) // col_num + 1
37
  containers = [st.container() for _ in range(rows)]
 
284
  # chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
285
  # tag = stx.tab_bar(chosen_data, key='tag', default='food')
286
 
287
+ # save tag to session state on change
288
  tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag')
289
 
290
  # tabs = st.tabs(prompt_tags)
 
293
  # tag = prompt_tags[i]
294
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
295
 
296
+ prompts = np.sort(items['prompt'].unique())[::1].tolist()
297
 
298
+ st.caption('Select a prompt')
299
  subset_selector = st.columns([3, 1])
300
  with subset_selector[0]:
301
+ # remember last prompt
302
+ # if 'prompt_idx_last_time' not in st.session_state:
303
+ # st.session_state.prompt_idx_last_time = 0
304
+
305
  # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
306
+ selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=0)
307
+ # st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
 
308
 
309
  if selected_prompt is None:
310
+ # st.markdown(':orange[Please select a prompt above👆]')
311
  st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
312
+
313
+ with subset_selector[1]:
314
+ st.write(':orange[👈 **Please select a prompt**]')
315
+
316
  else:
317
  items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
318
  prompt_id = items['prompt_id'].unique()[0]
319
  note = items['note'].unique()[0]
320
+
321
+ # add state to session state
322
+ if prompt_id not in st.session_state.gallery_state:
323
+ st.session_state.gallery_state[prompt_id] = 'graph'
324
 
325
  # add safety check for some prompts
326
  safety_check = True
 
336
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
337
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
338
 
339
+ print('current state: ', st.session_state.gallery_state[prompt_id])
340
+
341
+ if st.session_state.gallery_state[prompt_id] == 'graph':
342
+ if safety_check:
343
+ # if subset == 'Selected Only' and 'selected_dict' in st.session_state:
344
+ # # try:
345
+ # items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
346
+ # self.gallery_mode(prompt_id, items)
347
+ # # except:
348
+ # # st.warning('No selected images found')
349
+ # else:
350
  self.graph_mode(prompt_id, items)
351
+ with subset_selector[1]:
352
+ # if st.session_state.gallery_state[prompt_id] == 'graph':
353
+ # subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
354
+ has_selection = False
355
+ try:
356
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
357
+ has_selection = True
358
+ except:
359
+ pass
360
+
361
+ if has_selection:
362
+ checkout = st.button('Check out selections', use_container_width=True, type='primary')
363
+ if checkout:
364
+ print('checkout')
365
+
366
+ st.session_state.gallery_state[prompt_id] = 'gallery'
367
+ print(st.session_state.gallery_state[prompt_id])
368
+ st.experimental_rerun()
369
+ else:
370
+ st.write('Select images you like below 👇')
371
+
372
+ elif st.session_state.gallery_state[prompt_id] == 'gallery':
373
+ items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
374
+ drop=True)
375
+ self.gallery_mode(prompt_id, items)
376
+
377
+ with subset_selector[1]:
378
+ state_operations = st.columns([1, 1])
379
+ with state_operations[0]:
380
+ back = st.button('Back', use_container_width=True)
381
+ if back:
382
+ st.session_state.gallery_state[prompt_id] = 'graph'
383
+ st.experimental_rerun()
384
+
385
+ with state_operations[1]:
386
+ forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
387
+ if forward:
388
+ switch_page('ranking')
389
+
390
+
391
+
392
+ # else:
393
+ # st.button('Proceed', use_container_width=True)
394
 
395
  try:
396
  self.sidebar(items, prompt_id, note)
 
455
  infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
456
  st.table(infos_df)
457
 
 
 
458
  # for info in infos:
459
  # st.write(f"**{info}**:")
460
  # st.write(item[info])
 
462
  else:
463
  st.info('Please click on an image to show')
464
 
 
 
465
  def gallery_mode(self, prompt_id, items):
466
  items, info, col_num = self.selection_panel(items)
467
 
 
484
  # if prompt:
485
  # switch_page("ranking")
486
 
487
+ # with st.form(key=f'{prompt_id}'):
488
  # buttons = st.columns([1, 1, 1])
489
+ # buttons_space = st.columns([1, 1, 1])
490
+ gallery_space = st.empty()
491
+
492
+ # with buttons_space[0]:
493
+ # continue_btn = st.button('Proceed selections to ranking', use_container_width=True, type='primary')
494
+ # if continue_btn:
495
+ # # self.submit_actions('Continue', prompt_id)
496
+ # switch_page("ranking")
497
+ #
498
+ # with buttons_space[1]:
499
+ # deselect_btn = st.button('Deselect All', use_container_width=True)
500
+ # if deselect_btn:
501
+ # self.submit_actions('Deselect', prompt_id)
502
+ #
503
+ # with buttons_space[2]:
504
+ # refresh_btn = st.button('Refresh', on_click=gallery_space.empty, use_container_width=True)
505
 
506
+ with gallery_space.container():
507
+ with st.spinner('Loading images...'):
508
+ self.gallery_standard(items, col_num, info)
509
 
510
  # st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
511
 
 
 
512
  def submit_actions(self, status, prompt_id):
513
  # remove counter from session state
514
  # st.session_state.pop('counter', None)
 
532
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
533
  # switch_page("ranking")
534
  print(st.session_state.selected_dict, 'continue')
535
+ # st.experimental_rerun()
536
 
537
  def dynamic_weight(self, prompt_id, items, method='Grid Search'):
538
  selected = items[
 
715
  roster, promptBook, images_ds = load_hf_dataset()
716
  # print(promptBook.columns)
717
 
718
+ # # initialize selected_dict
719
+ # if 'selected_dict' not in st.session_state:
720
+ # st.session_state['selected_dict'] = {}
721
 
722
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
723
  app.app()
pages/Summary.py CHANGED
@@ -128,11 +128,11 @@ class DashboardApp:
128
  st.image(image, use_column_width=True)
129
 
130
  def score_calculator(self, results, db_table):
131
- # sort results by battle time
132
- results = sorted(results, key=lambda x: x['battletime'])
133
-
134
  modelVersion_standings = {}
135
  if db_table == 'battle_results':
 
 
 
136
  for record in results:
137
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
138
 
 
128
  st.image(image, use_column_width=True)
129
 
130
  def score_calculator(self, results, db_table):
 
 
 
131
  modelVersion_standings = {}
132
  if db_table == 'battle_results':
133
+ # sort results by battle time
134
+ results = sorted(results, key=lambda x: x['battletime'])
135
+
136
  for record in results:
137
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
138