Ricercar commited on
Commit
b776852
1 Parent(s): 0ae586e

beta testing for new gallery page

Browse files
Files changed (2) hide show
  1. Home.py +11 -3
  2. pages/Gallery.py +44 -32
Home.py CHANGED
@@ -42,12 +42,20 @@ def info():
42
  with st.sidebar:
43
  st.write('## About')
44
 
 
 
 
 
 
 
 
 
45
  st.write(
46
- "This is an web application to collect personal preference to images synthesised by generative models fine-tuned on stable diffusion. \
47
- **You might consider it as a tool for quickly digging out the most suitable text-to-image generation model for you from [civitai](https://civitai.com/).**"
48
  )
 
49
  st.write(
50
- "After you picking images from gallery page, and ranking them in the ranking page, you will be able to see a dashboard showing your preferred models in the summary page, **with download links of the models ready to use in [Automatic1111 webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)!**"
51
  )
52
 
53
 
 
42
  with st.sidebar:
43
  st.write('## About')
44
 
45
+ # st.write(
46
+ # "This is an web application to collect personal preference to images synthesised by generative models fine-tuned on stable diffusion. \
47
+ # **You might consider it as a tool for quickly digging out the most suitable text-to-image generation model for you from [civitai](https://civitai.com/).**"
48
+ # )
49
+ # st.write(
50
+ # "After you picking images from gallery page, and ranking them in the ranking page, you will be able to see a dashboard showing your preferred models in the summary page, **with download links of the models ready to use in [Automatic1111 webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)!**"
51
+ # )
52
+
53
  st.write(
54
+ "This is a web application for individual users to quickly dig out the most suitable text-to-image generation model from civitai. Our research aims to understand personal preference to images synthesized by generative models fine-tuned on stable diffusion and you can contribute by playing with this tool and giving us your feedback! "
 
55
  )
56
+
57
  st.write(
58
+ "After picking images you liked from Gallery and a battle-mode Ranking Contest, a summary dashboard will be presented indicating your preferred models with download links ready to be deployed in Webui !"
59
  )
60
 
61
 
pages/Gallery.py CHANGED
@@ -12,6 +12,7 @@ from datasets import load_dataset, Dataset, load_from_disk
12
  from huggingface_hub import login
13
  from streamlit_agraph import agraph, Node, Edge, Config
14
  from streamlit_extras.switch_page_button import switch_page
 
15
  from sklearn.svm import LinearSVC
16
 
17
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
@@ -226,6 +227,7 @@ class GalleryApp:
226
 
227
  # items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
228
 
 
229
 
230
  # show source
231
  if isinstance(note, str):
@@ -282,36 +284,46 @@ class GalleryApp:
282
 
283
  subset_selector = st.columns([3, 1])
284
  with subset_selector[0]:
285
- selected_prompt = st.selectbox('Select prompt', prompts, index=3)
 
286
  with subset_selector[1]:
287
- subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{selected_prompt}')
288
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
289
- prompt_id = items['prompt_id'].unique()[0]
290
- note = items['note'].unique()[0]
291
-
292
- # add safety check for some prompts
293
- safety_check = True
294
- unsafe_prompts = {}
295
- # initialize unsafe prompts
296
- for prompt_tag in prompt_tags:
297
- unsafe_prompts[prompt_tag] = []
298
- # manually add unsafe prompts
299
- unsafe_prompts['world knowledge'] = [83]
300
- unsafe_prompts['abstract'] = [1, 3]
301
-
302
- if int(prompt_id.item()) in unsafe_prompts[tag]:
303
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
304
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
305
-
306
- if safety_check:
307
-
308
- # if subset == 'Selected Only' and 'selected_dict' in st.session_state:
309
- # items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
310
- # self.gallery_mode(prompt_id, items)
311
- # else:
312
- self.graph_mode(prompt_id, items)
313
-
314
- self.sidebar(items, prompt_id, note)
 
 
 
 
 
 
 
 
 
315
 
316
  def graph_mode(self, prompt_id, items):
317
  graph_cols = st.columns([3, 1])
@@ -397,9 +409,9 @@ class GalleryApp:
397
  # with dynamic_weight_panel[i]:
398
  # btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
399
 
400
- prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
401
- if prompt:
402
- switch_page("ranking")
403
 
404
  with st.form(key=f'{prompt_id}'):
405
  # buttons = st.columns([1, 1, 1])
 
12
  from huggingface_hub import login
13
  from streamlit_agraph import agraph, Node, Edge, Config
14
  from streamlit_extras.switch_page_button import switch_page
15
+ from streamlit_extras.no_default_selectbox import selectbox
16
  from sklearn.svm import LinearSVC
17
 
18
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
 
227
 
228
  # items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
229
 
230
+ st.title('Model Visualization and Retrieval')
231
 
232
  # show source
233
  if isinstance(note, str):
 
284
 
285
  subset_selector = st.columns([3, 1])
286
  with subset_selector[0]:
287
+ # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
288
+ selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
289
  with subset_selector[1]:
290
+ subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
291
+
292
+ if selected_prompt is None:
293
+ st.markdown(':orange[Please select a prompt above👆]')
294
+ else:
295
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
296
+ prompt_id = items['prompt_id'].unique()[0]
297
+ note = items['note'].unique()[0]
298
+
299
+ # add safety check for some prompts
300
+ safety_check = True
301
+ unsafe_prompts = {}
302
+ # initialize unsafe prompts
303
+ for prompt_tag in prompt_tags:
304
+ unsafe_prompts[prompt_tag] = []
305
+ # manually add unsafe prompts
306
+ unsafe_prompts['world knowledge'] = [83]
307
+ unsafe_prompts['abstract'] = [1, 3]
308
+
309
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
310
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
311
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
312
+
313
+ if safety_check:
314
+
315
+ if subset == 'Selected Only' and 'selected_dict' in st.session_state:
316
+ # try:
317
+ items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
318
+ self.gallery_mode(prompt_id, items)
319
+ # except:
320
+ # st.warning('No selected images found')
321
+ else:
322
+ self.graph_mode(prompt_id, items)
323
+ try:
324
+ self.sidebar(items, prompt_id, note)
325
+ except:
326
+ pass
327
 
328
  def graph_mode(self, prompt_id, items):
329
  graph_cols = st.columns([3, 1])
 
409
  # with dynamic_weight_panel[i]:
410
  # btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
411
 
412
+ # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
413
+ # if prompt:
414
+ # switch_page("ranking")
415
 
416
  with st.form(key=f'{prompt_id}'):
417
  # buttons = st.columns([1, 1, 1])