Ricercar committed on
Commit
5bfac4b
1 Parent(s): 6963450

add new custom weighting mode

Browse files
Files changed (1) hide show
  1. app.py +190 -82
app.py CHANGED
@@ -10,6 +10,8 @@ from datasets import load_dataset, Dataset, load_from_disk
10
  from huggingface_hub import login
11
  import os
12
  import requests
 
 
13
 
14
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
15
 
@@ -57,7 +59,7 @@ class GalleryApp:
57
  with cols[j]:
58
  # show image
59
  image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
60
- # image = list(st.session_state.images.skip(items.iloc[idx+j]['row_idx'].item()).take(1))[0]['image']
61
  st.image(image,
62
  use_column_width=True,
63
  )
@@ -75,75 +77,22 @@ class GalleryApp:
75
  # with containers[row_idx+1]:
76
  # st.image(image, use_column_width=True)
77
 
78
- def app(self):
79
- st.title('Model Coffer Gallery')
80
- st.write('This is a gallery of images generated by the models in the Model Coffer')
81
-
82
- with st.sidebar:
83
- prompt_tags = self.promptBook['tag'].unique()
84
- # sort tags by alphabetical order
85
- prompt_tags = np.sort(prompt_tags)[::-1]
86
-
87
- tag = st.selectbox('Select a tag', prompt_tags)
88
-
89
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
90
-
91
- original_prompts = np.sort(items['prompt'].unique())[::-1]
92
-
93
- # remove the first four items in the prompt, which are mostly the same
94
- if tag != 'abstract':
95
- prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
96
- prompt = st.selectbox('Select prompt', prompts)
97
-
98
- idx = prompts.index(prompt)
99
- prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
100
- else:
101
- prompt_full = st.selectbox('Select prompt', original_prompts)
102
-
103
- prompt_id = items[items['prompt'] == prompt_full]['prompt_id'].unique()[0]
104
- items = items[items['prompt_id'] == prompt_id].reset_index(drop=True)
105
-
106
- st.write('**Prompt ID**')
107
- st.caption(f"{prompt_id}")
108
- st.write('**Prompt**')
109
- st.caption(f"{items['prompt'][0]}")
110
- st.write('**Negative Prompt**')
111
- st.caption(f"{items['negativePrompt'][0]}")
112
- st.write('**Sampler**')
113
- st.caption(f"{items['sampler'][0]}")
114
- st.write('**cfgScale**')
115
- st.caption(f"{items['cfgScale'][0]}")
116
- st.write('**Size**')
117
- st.caption(f"width: {items['size'][0].split('x')[0]}, height: {items['size'][0].split('x')[1]}")
118
- st.write('**Seed**')
119
- st.caption(f"{items['seed'][0]}")
120
-
121
- # # for tag as civitai, add civitai reference
122
- # if tag == 'civitai':
123
- # st.write('**Reference**')
124
- #
125
- # res = requests.get(f'https://civitai.com/images', params={'post_id': prompt_id})
126
- # st.write(res)
127
- # image_url = res.json()['items'][0]['url']
128
- # st.image(image_url, use_column_width=True)
129
-
130
- # with images:
131
- # selecters = st.columns([2, 1, 2, 0.5])
132
  selecters = st.columns([4, 1, 1])
133
 
134
  with selecters[0]:
135
- # # sort_by = st.selectbox('Sort by', items.columns[11: -1])
136
- # sort_by = st.selectbox('Sort by', ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
137
- # 'modelVersion_name', 'modelVersion_id'])
138
- print(items.columns)
139
  types = st.columns([1, 3])
140
  with types[0]:
141
  sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
142
  with types[1]:
143
  if sort_type == 'IDs and Names':
144
- sort_by = st.selectbox('Sort by', ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'], label_visibility='hidden')
 
 
145
  elif sort_type == 'Scores':
146
- sort_by = st.multiselect('Sort by', ['clip_score', 'avg_rank', 'popularity'], label_visibility='hidden', default=['clip_score', 'avg_rank', 'popularity'])
 
 
147
  # process sort_by to map to the column name
148
 
149
  if len(sort_by) == 3:
@@ -172,18 +121,24 @@ class GalleryApp:
172
  items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
173
 
174
  with selecters[2]:
175
- filter = st.selectbox('Filter', ['All', 'Checked', 'Unchecked'])
176
- if filter == 'Checked':
177
- items = items[items['checked'] is True].reset_index(drop=True)
178
- elif filter == 'Unchecked':
179
- items = items[items['checked'] is False].reset_index(drop=True)
 
 
 
 
 
 
180
 
181
  info = st.multiselect('Show Info',
182
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
183
- 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop', 'clip+rank+pop'],
 
184
  default=sort_by)
185
 
186
- print('info', info)
187
  # add one annotation
188
  mentioned_scores = []
189
  for i in info:
@@ -193,20 +148,173 @@ class GalleryApp:
193
  if SCORE_NAME_MAPPING[m] not in mentioned_scores:
194
  mentioned_scores.append(SCORE_NAME_MAPPING[m])
195
  if len(mentioned_scores) > 0:
196
- st.write(f"**Note: ** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
200
 
201
- with st.form(key=f'{prompt_id}', clear_on_submit=False):
202
- buttons = st.columns([1, 1, 1])
203
- with buttons[0]:
204
- submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
205
- with buttons[1]:
206
- submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
207
- with buttons[2]:
208
- submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
209
- self.gallery_standard(items, col_num, info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  def reset_current_prompt(self, prompt_id):
212
  # reset current prompt
@@ -223,10 +331,6 @@ class GalleryApp:
223
  dataset = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
224
  # get checked images
225
  checked_info = self.promptBook['checked']
226
- # print('checked_info: ', checked_info)
227
- # for d in checked_info:
228
- # if d is True:
229
- # print('checked')
230
 
231
  if 'checked' in dataset.column_names:
232
  dataset = dataset.remove_columns('checked')
@@ -254,6 +358,10 @@ if __name__ == '__main__':
254
  if 'checked' not in st.session_state.promptBook.columns:
255
  st.session_state.promptBook.loc[:, 'checked'] = False
256
 
 
 
 
 
257
  st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
258
  # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
259
  print(st.session_state.images)
 
10
  from huggingface_hub import login
11
  import os
12
  import requests
13
+ from bs4 import BeautifulSoup
14
+ import re
15
 
16
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
17
 
 
59
  with cols[j]:
60
  # show image
61
  image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
62
+
63
  st.image(image,
64
  use_column_width=True,
65
  )
 
77
  # with containers[row_idx+1]:
78
  # st.image(image, use_column_width=True)
79
 
80
+ def selection_panel(self, items):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  selecters = st.columns([4, 1, 1])
82
 
83
  with selecters[0]:
 
 
 
 
84
  types = st.columns([1, 3])
85
  with types[0]:
86
  sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
87
  with types[1]:
88
  if sort_type == 'IDs and Names':
89
+ sort_by = st.selectbox('Sort by',
90
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
91
+ label_visibility='hidden')
92
  elif sort_type == 'Scores':
93
+ sort_by = st.multiselect('Sort by', ['clip_score', 'avg_rank', 'popularity'],
94
+ label_visibility='hidden',
95
+ default=['clip_score', 'avg_rank', 'popularity'])
96
  # process sort_by to map to the column name
97
 
98
  if len(sort_by) == 3:
 
121
  items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
122
 
123
  with selecters[2]:
124
+ filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
125
+ print('filter', filter)
126
+ # initialize unsafe_modelVersion_ids
127
+ if filter == 'Safe':
128
+ # return checked items
129
+ items = items[items['checked'] == False].reset_index(drop=True)
130
+
131
+ elif filter == 'Unsafe':
132
+ # return unchecked items
133
+ items = items[items['checked'] == True].reset_index(drop=True)
134
+ print(items)
135
 
136
  info = st.multiselect('Show Info',
137
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
138
+ 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
139
+ 'clip+rank+pop'],
140
  default=sort_by)
141
 
 
142
  # add one annotation
143
  mentioned_scores = []
144
  for i in info:
 
148
  if SCORE_NAME_MAPPING[m] not in mentioned_scores:
149
  mentioned_scores.append(SCORE_NAME_MAPPING[m])
150
  if len(mentioned_scores) > 0:
151
+ st.info(
152
+ f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
153
+
154
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
155
+
156
+ return items, info, col_num
157
+
158
+
159
+ def selection_panel_2(self, items):
160
+ selecters = st.columns([1, 5])
161
+
162
+ with selecters[0]:
163
+ sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
164
+ if sort_type == 'Scores':
165
+ sort_by = 'weighted_score_sum'
166
+
167
+ with selecters[1]:
168
+ if sort_type == 'IDs and Names':
169
+ sub_selecters = st.columns([3, 1, 1])
170
+ with sub_selecters[0]:
171
+ sort_by = st.selectbox('Sort by',
172
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
173
+ label_visibility='hidden')
174
+
175
+ continue_idx = 1
176
+
177
+ else:
178
+ sub_selecters = st.columns([1, 1, 1, 1, 1])
179
+
180
+ with sub_selecters[0]:
181
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
182
+ with sub_selecters[1]:
183
+ rank_weight = st.number_input('Rank Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
184
+ with sub_selecters[2]:
185
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
186
+
187
+ items.loc[:, 'weighted_score_sum'] = items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
188
+ 'norm_pop'] * pop_weight
189
+
190
+ continue_idx = 3
191
+
192
+ with sub_selecters[continue_idx]:
193
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
194
+ if order == 'Ascending':
195
+ order = True
196
+ else:
197
+ order = False
198
+
199
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
200
+
201
+ with sub_selecters[continue_idx+1]:
202
+ filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
203
+ print('filter', filter)
204
+ # initialize unsafe_modelVersion_ids
205
+ if filter == 'Safe':
206
+ # return checked items
207
+ items = items[items['checked'] == False].reset_index(drop=True)
208
+
209
+ elif filter == 'Unsafe':
210
+ # return unchecked items
211
+ items = items[items['checked'] == True].reset_index(drop=True)
212
+ print(items)
213
+
214
+ info = st.multiselect('Show Info',
215
+ ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
216
+ 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
217
+ 'clip+rank+pop', 'weighted_score_sum'],
218
+ default=sort_by)
219
 
220
+ # add one annotation
221
+ mentioned_scores = []
222
+ for i in info:
223
+ if '+' in i:
224
+ mentioned = i.split('+')
225
+ for m in mentioned:
226
+ if SCORE_NAME_MAPPING[m] not in mentioned_scores:
227
+ mentioned_scores.append(SCORE_NAME_MAPPING[m])
228
+ if len(mentioned_scores) > 0:
229
+ st.info(
230
+ f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
231
 
232
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
233
 
234
+ return items, info, col_num
235
+
236
+ def app(self):
237
+ st.title('Model Coffer Gallery')
238
+ st.write('This is a gallery of images generated by the models in the Model Coffer')
239
+
240
+ with st.sidebar:
241
+ prompt_tags = self.promptBook['tag'].unique()
242
+ # sort tags by alphabetical order
243
+ prompt_tags = np.sort(prompt_tags)[::-1]
244
+
245
+ tag = st.selectbox('Select a tag', prompt_tags)
246
+
247
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
248
+
249
+ original_prompts = np.sort(items['prompt'].unique())[::-1]
250
+
251
+ # remove the first four items in the prompt, which are mostly the same
252
+ if tag != 'abstract':
253
+ prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
254
+ prompt = st.selectbox('Select prompt', prompts)
255
+
256
+ idx = prompts.index(prompt)
257
+ prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
258
+ else:
259
+ prompt_full = st.selectbox('Select prompt', original_prompts)
260
+
261
+ prompt_id = items[items['prompt'] == prompt_full]['prompt_id'].unique()[0]
262
+ items = items[items['prompt_id'] == prompt_id].reset_index(drop=True)
263
+
264
+ # show image metadata
265
+ image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
266
+ for key in image_metadatas:
267
+ label = ' '.join(key.split('_')).capitalize()
268
+ st.write(f"**{label}**")
269
+ if items[key][0] == ' ':
270
+ st.write('`None`')
271
+ else:
272
+ st.caption(f"{items[key][0]}")
273
+
274
+ # for tag as civitai, add civitai reference
275
+ if tag == 'civitai':
276
+ try:
277
+ st.write('**Civitai Reference**')
278
+ res = requests.get(f'https://civitai.com/images/{prompt_id.item()}')
279
+ # st.write(res.text)
280
+ soup = BeautifulSoup(res.text, 'html.parser')
281
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
282
+ image_url = image_section.find('img')['src']
283
+ st.image(image_url, use_column_width=True)
284
+ except:
285
+ pass
286
+
287
+
288
+ # add safety check for some prompts
289
+ safety_check = True
290
+ unsafe_prompts = {}
291
+ # initialize unsafe prompts
292
+ for prompt_tag in prompt_tags:
293
+ unsafe_prompts[prompt_tag] = []
294
+ # manually add unsafe prompts
295
+ unsafe_prompts['civitai'] = [375790, 366222, 295008, 256477]
296
+ unsafe_prompts['people'] = [53]
297
+ unsafe_prompts['art'] = [23]
298
+ unsafe_prompts['abstract'] = [10, 12]
299
+
300
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
301
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
302
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.')
303
+
304
+ if safety_check:
305
+ items, info, col_num = self.selection_panel_2(items)
306
+ # self.gallery_standard(items, col_num, info)
307
+
308
+ with st.form(key=f'{prompt_id}', clear_on_submit=False):
309
+ buttons = st.columns([1, 1, 1])
310
+ with buttons[0]:
311
+ submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
312
+ with buttons[1]:
313
+ submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
314
+ with buttons[2]:
315
+ submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
316
+
317
+ self.gallery_standard(items, col_num, info)
318
 
319
  def reset_current_prompt(self, prompt_id):
320
  # reset current prompt
 
331
  dataset = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
332
  # get checked images
333
  checked_info = self.promptBook['checked']
 
 
 
 
334
 
335
  if 'checked' in dataset.column_names:
336
  dataset = dataset.remove_columns('checked')
 
358
  if 'checked' not in st.session_state.promptBook.columns:
359
  st.session_state.promptBook.loc[:, 'checked'] = False
360
 
361
+ # add 'custom_score_weights' column to promptBook if not exist
362
+ if 'weighted_score_sum' not in st.session_state.promptBook.columns:
363
+ st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
364
+
365
  st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
366
  # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
367
  print(st.session_state.images)