Ricercar commited on
Commit
fa4bfff
1 Parent(s): d4d6074

Update Gallery.py

Browse files
Files changed (1) hide show
  1. pages/Gallery.py +9 -11
pages/Gallery.py CHANGED
@@ -82,11 +82,11 @@ class GalleryApp:
82
  sub_selecters = st.columns([1, 1, 1, 1])
83
 
84
  with sub_selecters[0]:
85
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
86
  with sub_selecters[1]:
87
- mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
88
  with sub_selecters[2]:
89
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
 
91
  items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
  'norm_pop'] * pop_weight, 4)
@@ -94,13 +94,13 @@ class GalleryApp:
94
  continue_idx = 3
95
 
96
  # save latest weights
97
- st.session_state.score_weights[0] = clip_weight
98
- st.session_state.score_weights[1] = mcos_weight
99
- st.session_state.score_weights[2] = pop_weight
100
 
101
  # select threshold
102
  with sub_selecters[continue_idx]:
103
- nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
  items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
 
106
  # save latest threshold
@@ -214,6 +214,7 @@ class GalleryApp:
214
  st.write('This is a gallery of images generated by the models')
215
 
216
  prompt_tags, tag, prompt_id, items = self.sidebar()
 
217
 
218
  # add safety check for some prompts
219
  safety_check = True
@@ -223,16 +224,13 @@ class GalleryApp:
223
  unsafe_prompts[prompt_tag] = []
224
  # manually add unsafe prompts
225
  unsafe_prompts['world knowledge'] = [83]
226
- # unsafe_prompts['art'] = [23]
227
  unsafe_prompts['abstract'] = [1, 3]
228
- # unsafe_prompts['food'] = [34]
229
 
230
  if int(prompt_id.item()) in unsafe_prompts[tag]:
231
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
232
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
 
234
  if safety_check:
235
- items, info, col_num = self.selection_panel(items)
236
 
237
  if 'selected_dict' in st.session_state:
238
  # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
 
82
  sub_selecters = st.columns([1, 1, 1, 1])
83
 
84
  with sub_selecters[0]:
85
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1, help='the weight for normalized clip score')
86
  with sub_selecters[1]:
87
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=0.8, step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
88
  with sub_selecters[2]:
89
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=0.2, step=0.1, help='the weight for normalized popularity score')
90
 
91
  items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
  'norm_pop'] * pop_weight, 4)
 
94
  continue_idx = 3
95
 
96
  # save latest weights
97
+ st.session_state.score_weights[0] = round(clip_weight, 2)
98
+ st.session_state.score_weights[1] = round(mcos_weight, 2)
99
+ st.session_state.score_weights[2] = round(pop_weight, 2)
100
 
101
  # select threshold
102
  with sub_selecters[continue_idx]:
103
+ nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=0.8, step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
  items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
 
106
  # save latest threshold
 
214
  st.write('This is a gallery of images generated by the models')
215
 
216
  prompt_tags, tag, prompt_id, items = self.sidebar()
217
+ items, info, col_num = self.selection_panel(items)
218
 
219
  # add safety check for some prompts
220
  safety_check = True
 
224
  unsafe_prompts[prompt_tag] = []
225
  # manually add unsafe prompts
226
  unsafe_prompts['world knowledge'] = [83]
 
227
  unsafe_prompts['abstract'] = [1, 3]
 
228
 
229
  if int(prompt_id.item()) in unsafe_prompts[tag]:
230
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
231
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
232
 
233
  if safety_check:
 
234
 
235
  if 'selected_dict' in st.session_state:
236
  # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))