Ricercar commited on
Commit
c9b39a0
1 Parent(s): b776852

beta development for new gallery view

Browse files
Files changed (3) hide show
  1. pages/Gallery.py +57 -47
  2. pages/Ranking.py +2 -1
  3. requirements.txt +1 -0
pages/Gallery.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import requests
3
 
4
  import altair as alt
 
5
  import numpy as np
6
  import pandas as pd
7
  import streamlit as st
@@ -263,7 +264,6 @@ class GalleryApp:
263
 
264
  # return prompt_tags, tag, prompt_id, items
265
 
266
-
267
  def app(self):
268
  # st.title('Model Visualization and Retrieval')
269
  # st.write('This is a gallery of images generated by the models')
@@ -273,53 +273,60 @@ class GalleryApp:
273
  # sort tags by alphabetical order
274
  prompt_tags = np.sort(prompt_tags)[::1].tolist()
275
 
276
- tabs = st.tabs(prompt_tags)
277
- with st.spinner('Loading...'):
278
- for i in range(len(prompt_tags)):
279
- with tabs[i]:
280
- tag = prompt_tags[i]
281
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
282
 
283
- prompts = np.sort(items['prompt'].unique())[::1]
284
 
285
- subset_selector = st.columns([3, 1])
286
- with subset_selector[0]:
287
- # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
288
- selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
289
- with subset_selector[1]:
290
- subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- if selected_prompt is None:
293
- st.markdown(':orange[Please select a prompt above👆]')
294
- else:
295
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
296
- prompt_id = items['prompt_id'].unique()[0]
297
- note = items['note'].unique()[0]
298
-
299
- # add safety check for some prompts
300
- safety_check = True
301
- unsafe_prompts = {}
302
- # initialize unsafe prompts
303
- for prompt_tag in prompt_tags:
304
- unsafe_prompts[prompt_tag] = []
305
- # manually add unsafe prompts
306
- unsafe_prompts['world knowledge'] = [83]
307
- unsafe_prompts['abstract'] = [1, 3]
308
-
309
- if int(prompt_id.item()) in unsafe_prompts[tag]:
310
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
311
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
312
-
313
- if safety_check:
314
-
315
- if subset == 'Selected Only' and 'selected_dict' in st.session_state:
316
- # try:
317
- items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
318
- self.gallery_mode(prompt_id, items)
319
- # except:
320
- # st.warning('No selected images found')
321
- else:
322
- self.graph_mode(prompt_id, items)
323
  try:
324
  self.sidebar(items, prompt_id, note)
325
  except:
@@ -383,6 +390,8 @@ class GalleryApp:
383
  infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
384
  st.table(infos_df)
385
 
 
 
386
  # for info in infos:
387
  # st.write(f"**{info}**:")
388
  # st.write(item[info])
@@ -391,6 +400,7 @@ class GalleryApp:
391
  st.info('Please click on an image to show')
392
 
393
 
 
394
  def gallery_mode(self, prompt_id, items):
395
  items, info, col_num = self.selection_panel(items)
396
 
@@ -627,9 +637,9 @@ def load_tsne_coordinates(items):
627
 
628
  # print(tsne_df['modelVersion_id'].dtype)
629
 
630
- print('before merge:', items)
631
  items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
632
- print('after merge:', items)
633
  return items
634
 
635
 
 
2
  import requests
3
 
4
  import altair as alt
5
+ import extra_streamlit_components as stx
6
  import numpy as np
7
  import pandas as pd
8
  import streamlit as st
 
264
 
265
  # return prompt_tags, tag, prompt_id, items
266
 
 
267
  def app(self):
268
  # st.title('Model Visualization and Retrieval')
269
  # st.write('This is a gallery of images generated by the models')
 
273
  # sort tags by alphabetical order
274
  prompt_tags = np.sort(prompt_tags)[::1].tolist()
275
 
276
+ # chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
277
+ # tag = stx.tab_bar(chosen_data, key='tag', default='food')
 
 
 
 
278
 
279
+ tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag')
280
 
281
+ # tabs = st.tabs(prompt_tags)
282
+ # for i in range(len(prompt_tags)):
283
+ # with tabs[i]:
284
+ # tag = prompt_tags[i]
285
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
286
+
287
+ prompts = np.sort(items['prompt'].unique())[::1]
288
+
289
+ subset_selector = st.columns([3, 1])
290
+ with subset_selector[0]:
291
+ # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
292
+ selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
293
+ with subset_selector[1]:
294
+ subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
295
+
296
+ if selected_prompt is None:
297
+ st.markdown(':orange[Please select a prompt above👆]')
298
+ st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
299
+ else:
300
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
301
+ prompt_id = items['prompt_id'].unique()[0]
302
+ note = items['note'].unique()[0]
303
+ print(prompt_id, note)
304
+
305
+ # add safety check for some prompts
306
+ safety_check = True
307
+ unsafe_prompts = {}
308
+ # initialize unsafe prompts
309
+ for prompt_tag in prompt_tags:
310
+ unsafe_prompts[prompt_tag] = []
311
+ # manually add unsafe prompts
312
+ unsafe_prompts['world knowledge'] = [83]
313
+ unsafe_prompts['abstract'] = [1, 3]
314
+
315
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
316
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
317
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
318
+
319
+ if safety_check:
320
+
321
+ if subset == 'Selected Only' and 'selected_dict' in st.session_state:
322
+ # try:
323
+ items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
324
+ self.gallery_mode(prompt_id, items)
325
+ # except:
326
+ # st.warning('No selected images found')
327
+ else:
328
+ self.graph_mode(prompt_id, items)
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  try:
331
  self.sidebar(items, prompt_id, note)
332
  except:
 
390
  infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
391
  st.table(infos_df)
392
 
393
+ st.button('🎖️ Proceed selections to ranking', on_click=switch_page, args=("ranking",), use_container_width=True,)
394
+
395
  # for info in infos:
396
  # st.write(f"**{info}**:")
397
  # st.write(item[info])
 
400
  st.info('Please click on an image to show')
401
 
402
 
403
+
404
  def gallery_mode(self, prompt_id, items):
405
  items, info, col_num = self.selection_panel(items)
406
 
 
637
 
638
  # print(tsne_df['modelVersion_id'].dtype)
639
 
640
+ # print('before merge:', items)
641
  items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
642
+ # print('after merge:', items)
643
  return items
644
 
645
 
pages/Ranking.py CHANGED
@@ -255,7 +255,8 @@ class RankingApp:
255
 
256
  def app(self):
257
  st.title('Personal Image Ranking')
258
- st.write('Here you can test out your selected images with any prompt you like.')
 
259
  # st.write(self.promptBook)
260
 
261
  # save the current progress to session state
 
255
 
256
  def app(self):
257
  st.title('Personal Image Ranking')
258
+ st.write('Here you can test out your selected images with any prompt you like. ')
259
+ st.caption("We might pair some other images that you haven't selected based on our evaluation matrix.")
260
  # st.write(self.promptBook)
261
 
262
  # save the current progress to session state
requirements.txt CHANGED
@@ -6,3 +6,4 @@ altair<5
6
  streamlit-vega-lite
7
  scikit-learn
8
  pymysql
 
 
6
  streamlit-vega-lite
7
  scikit-learn
8
  pymysql
9
+ extra_streamlit_components