Ricercar committed on
Commit
48bb321
1 Parent(s): a3efb89

add disabled chat message

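This commit adds a chat input rendered in a disabled state at the bottom of pages/Gallery.py, used as a read-only status line that echoes the current selection, and it re-enables the previously commented-out chat demo in pages/streamlit-1.25.py. A minimal sketch of the disabled-input pattern, assuming Streamlit 1.24+; the key and placeholder below are illustrative, not the exact values used in the page:

import streamlit as st

# Selection state kept by the gallery page; an empty dict/list is assumed here.
selected = st.session_state.get("selected_dict", {}).get("example_prompt_id", [])

# disabled=True keeps the widget visible but not editable, so its placeholder
# text can act as a status bar showing which model versions are checked.
st.chat_input(f"checked: {selected}", disabled=True, key="example_status_input")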
Archive/Gallery_archive_8_5.py DELETED
@@ -1,446 +0,0 @@
1
- import os
2
- import requests
3
-
4
- import altair as alt
5
- import numpy as np
6
- import pandas as pd
7
- import streamlit as st
8
-
9
- from bs4 import BeautifulSoup
10
- from datasets import load_dataset, Dataset, load_from_disk
11
- from huggingface_hub import login
12
- from streamlit_extras.switch_page_button import switch_page
13
- from sklearn.svm import LinearSVC
14
-
15
- SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
-
17
-
18
- class GalleryApp:
19
- def __init__(self, promptBook, images_ds):
20
- self.promptBook = promptBook
21
- self.images_ds = images_ds
22
-
23
- def gallery_standard(self, items, col_num, info):
24
- rows = len(items) // col_num + 1
25
- containers = [st.container() for _ in range(rows)]
26
- for idx in range(0, len(items), col_num):
27
- row_idx = idx // col_num
28
- with containers[row_idx]:
29
- cols = st.columns(col_num)
30
- for j in range(col_num):
31
- if idx + j < len(items):
32
- with cols[j]:
33
- # show image
34
- # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
35
- # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
36
- image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
37
- st.image(image, use_column_width=True)
38
-
39
- # handle checkbox information
40
- prompt_id = items.iloc[idx + j]['prompt_id']
41
- modelVersion_id = items.iloc[idx + j]['modelVersion_id']
42
-
43
- check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
44
-
45
- st.write("Position: ", idx + j)
46
-
47
- # show checkbox
48
- st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
49
-
50
- # show selected info
51
- for key in info:
52
- st.write(f"**{key}**: {items.iloc[idx + j][key]}")
53
-
54
- def selection_panel(self, items):
55
- # temporary function
56
-
57
- selecters = st.columns([1, 4])
58
-
59
- if 'score_weights' not in st.session_state:
60
- st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
61
-
62
- # select sort type
63
- with selecters[0]:
64
- sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
65
- if sort_type == 'Scores':
66
- sort_by = 'weighted_score_sum'
67
-
68
- # select other options
69
- with selecters[1]:
70
- if sort_type == 'IDs and Names':
71
- sub_selecters = st.columns([3, 1])
72
- # select sort by
73
- with sub_selecters[0]:
74
- sort_by = st.selectbox('Sort by',
75
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
76
- label_visibility='hidden')
77
-
78
- continue_idx = 1
79
-
80
- else:
81
- # add custom weights
82
- sub_selecters = st.columns([1, 1, 1, 1])
83
-
84
- with sub_selecters[0]:
85
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
86
- with sub_selecters[1]:
87
- mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(ean) s(imilarity) q(uantile) score for measuring distinctiveness')
88
- with sub_selecters[2]:
89
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
-
91
- items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
- 'norm_pop'] * pop_weight, 4)
93
-
94
- continue_idx = 3
95
-
96
- # save latest weights
97
- st.session_state.score_weights[0] = clip_weight
98
- st.session_state.score_weights[1] = mcos_weight
99
- st.session_state.score_weights[2] = pop_weight
100
-
101
- # select threshold
102
- with sub_selecters[continue_idx]:
103
- nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
- items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
-
106
- # save latest threshold
107
- st.session_state.score_weights[3] = nsfw_threshold
108
-
109
- # draw a distribution histogram
110
- if sort_type == 'Scores':
111
- try:
112
- with st.expander('Show score distribution histogram and select score range'):
113
- st.write('**Score distribution histogram**')
114
- chart_space = st.container()
115
- # st.write('Select the range of scores to show')
116
- hist_data = pd.DataFrame(items[sort_by])
117
- mini = hist_data[sort_by].min().item()
118
- mini = mini//0.1 * 0.1
119
- maxi = hist_data[sort_by].max().item()
120
- maxi = maxi//0.1 * 0.1 + 0.1
121
- st.write('**Select the range of scores to show**')
122
- r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
123
- with chart_space:
124
- st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
125
- # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
126
- # r = event_dict.get(sort_by)
127
- if r:
128
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
129
- # st.write(r)
130
- except:
131
- pass
132
-
133
- display_options = st.columns([1, 4])
134
-
135
- with display_options[0]:
136
- # select order
137
- order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
138
- if order == 'Ascending':
139
- order = True
140
- else:
141
- order = False
142
-
143
- with display_options[1]:
144
-
145
- # select info to show
146
- info = st.multiselect('Show Info',
147
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
148
- 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
149
- 'nsfw_score', 'norm_nsfw'],
150
- default=sort_by)
151
-
152
- # apply sorting to dataframe
153
- items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
154
-
155
- # select number of columns
156
- col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
157
-
158
- return items, info, col_num
159
-
160
- def sidebar(self):
161
- with st.sidebar:
162
- prompt_tags = self.promptBook['tag'].unique()
163
- # sort tags by alphabetical order
164
- prompt_tags = np.sort(prompt_tags)[::-1]
165
-
166
- tag = st.selectbox('Select a tag', prompt_tags)
167
-
168
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
169
-
170
- prompts = np.sort(items['prompt'].unique())[::-1]
171
-
172
- selected_prompt = st.selectbox('Select prompt', prompts)
173
-
174
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
175
- prompt_id = items['prompt_id'].unique()[0]
176
- note = items['note'].unique()[0]
177
-
178
- # show source
179
- if isinstance(note, str):
180
- if note.isdigit():
181
- st.caption(f"`Source: civitai`")
182
- else:
183
- st.caption(f"`Source: {note}`")
184
- else:
185
- st.caption("`Source: Parti-prompts`")
186
-
187
- # show image metadata
188
- image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
189
- for key in image_metadatas:
190
- label = ' '.join(key.split('_')).capitalize()
191
- st.write(f"**{label}**")
192
- if items[key][0] == ' ':
193
- st.write('`None`')
194
- else:
195
- st.caption(f"{items[key][0]}")
196
-
197
- # for note as civitai image id, add civitai reference
198
- if isinstance(note, str) and note.isdigit():
199
- try:
200
- st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
201
- res = requests.get(f'https://civitai.com/images/{note}')
202
- # st.write(res.text)
203
- soup = BeautifulSoup(res.text, 'html.parser')
204
- image_section = soup.find('div', {'class': 'mantine-12rlksp'})
205
- image_url = image_section.find('img')['src']
206
- st.image(image_url, use_column_width=True)
207
- except:
208
- pass
209
-
210
- return prompt_tags, tag, prompt_id, items
211
-
212
- def app(self):
213
- st.title('Model Visualization and Retrieval')
214
- st.write('This is a gallery of images generated by the models')
215
-
216
- prompt_tags, tag, prompt_id, items = self.sidebar()
217
-
218
- # add safety check for some prompts
219
- safety_check = True
220
- unsafe_prompts = {}
221
- # initialize unsafe prompts
222
- for prompt_tag in prompt_tags:
223
- unsafe_prompts[prompt_tag] = []
224
- # manually add unsafe prompts
225
- unsafe_prompts['world knowledge'] = [83]
226
- # unsafe_prompts['art'] = [23]
227
- unsafe_prompts['abstract'] = [1, 3]
228
- # unsafe_prompts['food'] = [34]
229
-
230
- if int(prompt_id.item()) in unsafe_prompts[tag]:
231
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
232
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
-
234
- if safety_check:
235
- items, info, col_num = self.selection_panel(items)
236
-
237
- if 'selected_dict' in st.session_state:
238
- st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
239
- dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
240
- dynamic_weight_panel = st.columns(len(dynamic_weight_options))
241
-
242
- if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
243
- btn_disable = False
244
- else:
245
- btn_disable = True
246
-
247
- for i in range(len(dynamic_weight_options)):
248
- method = dynamic_weight_options[i]
249
- with dynamic_weight_panel[i]:
250
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
251
-
252
- with st.form(key=f'{prompt_id}'):
253
- # buttons = st.columns([1, 1, 1])
254
- buttons_space = st.columns([1, 1, 1, 1])
255
- gallery_space = st.empty()
256
-
257
- with buttons_space[0]:
258
- continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
259
- if continue_btn:
260
- self.submit_actions('Continue', prompt_id)
261
-
262
- with buttons_space[1]:
263
- select_btn = st.form_submit_button('Select All', use_container_width=True)
264
- if select_btn:
265
- self.submit_actions('Select', prompt_id)
266
-
267
- with buttons_space[2]:
268
- deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
269
- if deselect_btn:
270
- self.submit_actions('Deselect', prompt_id)
271
-
272
- with buttons_space[3]:
273
- refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
274
-
275
- with gallery_space.container():
276
- with st.spinner('Loading images...'):
277
- self.gallery_standard(items, col_num, info)
278
-
279
- def submit_actions(self, status, prompt_id):
280
- if status == 'Select':
281
- modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
282
- st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
283
- print(st.session_state.selected_dict, 'select')
284
- st.experimental_rerun()
285
- elif status == 'Deselect':
286
- st.session_state.selected_dict[prompt_id] = []
287
- print(st.session_state.selected_dict, 'deselect')
288
- st.experimental_rerun()
289
- # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
290
- elif status == 'Continue':
291
- st.session_state.selected_dict[prompt_id] = []
292
- for key in st.session_state:
293
- keys = key.split('_')
294
- if keys[0] == 'select' and keys[1] == str(prompt_id):
295
- if st.session_state[key]:
296
- st.session_state.selected_dict[prompt_id].append(int(keys[2]))
297
- # switch_page("ranking")
298
- print(st.session_state.selected_dict, 'continue')
299
- st.experimental_rerun()
300
-
301
- def dynamic_weight(self, prompt_id, items, method='Grid Search'):
302
- selected = items[
303
- items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
304
- optimal_weight = [0, 0, 0]
305
-
306
- if method == 'Grid Search':
307
- # grid search method
308
- top_ranking = len(items) * len(selected)
309
-
310
- for clip_weight in np.arange(-1, 1, 0.1):
311
- for mcos_weight in np.arange(-1, 1, 0.1):
312
- for pop_weight in np.arange(-1, 1, 0.1):
313
-
314
- weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
315
- weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
316
- # print('weight_all_sorted:', weight_all_sorted)
317
- weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
318
-
319
- # get the index of values of weight_selected in weight_all_sorted
320
- rankings = []
321
- for weight in weight_selected:
322
- rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
323
- if sum(rankings) <= top_ranking:
324
- top_ranking = sum(rankings)
325
- print('current top ranking:', top_ranking, rankings)
326
- optimal_weight = [clip_weight, mcos_weight, pop_weight]
327
- print('optimal weight:', optimal_weight)
328
-
329
- elif method == 'SVM':
330
- # svm method
331
- print('start svm method')
332
- # get residual dataframe that contains models not selected
333
- residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
334
- residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
335
- residual = residual.to_numpy()
336
- selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
337
- selected = selected.to_numpy()
338
-
339
- y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
340
- X = np.concatenate((selected, residual), axis=0)
341
-
342
- # fit svm model, and get parameters for the hyperplane
343
- clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
344
- clf.fit(X, y)
345
- optimal_weight = clf.coef_[0].tolist()
346
- print('optimal weight:', optimal_weight)
347
- pass
348
-
349
- elif method == 'Greedy':
350
- for idx in selected.index:
351
- # find which score is the highest, clip, mcos, or pop
352
- clip_score = selected.loc[idx, 'norm_clip_crop']
353
- mcos_score = selected.loc[idx, 'norm_mcos_crop']
354
- pop_score = selected.loc[idx, 'norm_pop']
355
- if clip_score >= mcos_score and clip_score >= pop_score:
356
- optimal_weight[0] += 1
357
- elif mcos_score >= clip_score and mcos_score >= pop_score:
358
- optimal_weight[1] += 1
359
- elif pop_score >= clip_score and pop_score >= mcos_score:
360
- optimal_weight[2] += 1
361
-
362
- # normalize optimal_weight
363
- optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
364
- print('optimal weight:', optimal_weight)
365
-
366
- st.session_state.score_weights[0: 3] = optimal_weight
367
-
368
-
369
- # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
370
- @st.cache_resource
371
- def altair_histogram(hist_data, sort_by, mini, maxi):
372
- brushed = alt.selection_interval(encodings=['x'], name="brushed")
373
-
374
- chart = (
375
- alt.Chart(hist_data)
376
- .mark_bar(opacity=0.7, cornerRadius=2)
377
- .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
378
- # .add_selection(brushed)
379
- # .properties(width=800, height=300)
380
- )
381
-
382
- # Create a transparent rectangle for highlighting the range
383
- highlight = (
384
- alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
385
- .mark_rect(opacity=0.3)
386
- .encode(x='x1', x2='x2')
387
- # .properties(width=800, height=300)
388
- )
389
-
390
- # Layer the chart and the highlight rectangle
391
- layered_chart = alt.layer(chart, highlight)
392
-
393
- return layered_chart
394
-
395
-
396
- @st.cache_data
397
- def load_hf_dataset():
398
- # login to huggingface
399
- login(token=os.environ.get("HF_TOKEN"))
400
-
401
- # load from huggingface
402
- roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
403
- promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
404
- # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
405
- images_ds = None # set to None for now since we use s3 bucket to store images
406
-
407
- # process dataset
408
- roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
409
- 'model_download_count']].drop_duplicates().reset_index(drop=True)
410
-
411
- # add 'custom_score_weights' column to promptBook if not exist
412
- if 'weighted_score_sum' not in promptBook.columns:
413
- promptBook.loc[:, 'weighted_score_sum'] = 0
414
-
415
- # merge roster and promptbook
416
- promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
417
- on=['model_id', 'modelVersion_id'], how='left')
418
-
419
- # add column to record current row index
420
- promptBook.loc[:, 'row_idx'] = promptBook.index
421
-
422
- return roster, promptBook, images_ds
423
-
424
-
425
- if __name__ == "__main__":
426
- st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
427
-
428
- # remove ranking in the session state if it is created in Ranking.py
429
- st.session_state.pop('ranking', None)
430
-
431
- if 'user_id' not in st.session_state:
432
- st.warning('Please log in first.')
433
- home_btn = st.button('Go to Home Page')
434
- if home_btn:
435
- switch_page("home")
436
- else:
437
- st.write('You have already logged in as ' + st.session_state.user_id[0])
438
- roster, promptBook, images_ds = load_hf_dataset()
439
- # print(promptBook.columns)
440
-
441
- # initialize selected_dict
442
- if 'selected_dict' not in st.session_state:
443
- st.session_state['selected_dict'] = {}
444
-
445
- app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
446
- app.app()
pages/Gallery.py CHANGED
@@ -14,353 +14,358 @@ from sklearn.svm import LinearSVC
14
 
15
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
 
17
- def gallery_standard(items, col_num, info):
18
- rows = len(items) // col_num + 1
19
- containers = [st.container() for _ in range(rows)]
20
- for idx in range(0, len(items), col_num):
21
- row_idx = idx // col_num
22
- with containers[row_idx]:
23
- cols = st.columns(col_num)
24
- for j in range(col_num):
25
- if idx + j < len(items):
26
- with cols[j]:
27
- # show image
28
- # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
29
- # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
30
- image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
31
- st.image(image, use_column_width=True)
32
-
33
- # handle checkbox information
34
- prompt_id = items.iloc[idx + j]['prompt_id']
35
- modelVersion_id = items.iloc[idx + j]['modelVersion_id']
36
-
37
- check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
38
-
39
- st.write("Position: ", idx + j)
40
-
41
- # show checkbox
42
- st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
43
-
44
- # show selected info
45
- for key in info:
46
- st.write(f"**{key}**: {items.iloc[idx + j][key]}")
47
-
48
- def selection_panel(items):
49
- # temporary function
50
-
51
- selecters = st.columns([1, 4])
52
-
53
- if 'score_weights' not in st.session_state:
54
- st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
55
-
56
- # select sort type
57
- with selecters[0]:
58
- sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
59
- if sort_type == 'Scores':
60
- sort_by = 'weighted_score_sum'
61
-
62
- # select other options
63
- with selecters[1]:
64
- if sort_type == 'IDs and Names':
65
- sub_selecters = st.columns([3, 1])
66
- # select sort by
67
- with sub_selecters[0]:
68
- sort_by = st.selectbox('Sort by',
69
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
70
- label_visibility='hidden')
71
-
72
- continue_idx = 1
73
-
74
- else:
75
- # add custom weights
76
- sub_selecters = st.columns([1, 1, 1, 1])
77
-
78
- with sub_selecters[0]:
79
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
80
- with sub_selecters[1]:
81
- mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(ean) s(imilarity) q(uantile) score for measuring distinctiveness')
82
- with sub_selecters[2]:
83
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
84
-
85
- items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
86
- 'norm_pop'] * pop_weight, 4)
87
-
88
- continue_idx = 3
89
-
90
- # save latest weights
91
- st.session_state.score_weights[0] = clip_weight
92
- st.session_state.score_weights[1] = mcos_weight
93
- st.session_state.score_weights[2] = pop_weight
94
-
95
- # select threshold
96
- with sub_selecters[continue_idx]:
97
- nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
98
- items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
99
-
100
- # save latest threshold
101
- st.session_state.score_weights[3] = nsfw_threshold
102
-
103
- # draw a distribution histogram
104
- if sort_type == 'Scores':
105
- try:
106
- with st.expander('Show score distribution histogram and select score range'):
107
- st.write('**Score distribution histogram**')
108
- chart_space = st.container()
109
- # st.write('Select the range of scores to show')
110
- hist_data = pd.DataFrame(items[sort_by])
111
- mini = hist_data[sort_by].min().item()
112
- mini = mini//0.1 * 0.1
113
- maxi = hist_data[sort_by].max().item()
114
- maxi = maxi//0.1 * 0.1 + 0.1
115
- st.write('**Select the range of scores to show**')
116
- r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
117
- with chart_space:
118
- st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
119
- # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
120
- # r = event_dict.get(sort_by)
121
- if r:
122
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
123
- # st.write(r)
124
- except:
125
- pass
126
 
127
- display_options = st.columns([1, 4])
128
 
129
- with display_options[0]:
130
- # select order
131
- order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
132
- if order == 'Ascending':
133
- order = True
134
- else:
135
- order = False
136
 
137
- with display_options[1]:
138
 
139
- # select info to show
140
- info = st.multiselect('Show Info',
141
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
142
- 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
143
- 'nsfw_score', 'norm_nsfw'],
144
- default=sort_by)
145
 
146
- # apply sorting to dataframe
147
- items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
148
 
149
- # select number of columns
150
- col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
151
 
152
- return items, info, col_num
153
 
154
- def sidebar(promptBook, images_ds):
155
- with st.sidebar:
156
- prompt_tags = promptBook['tag'].unique()
157
- # sort tags by alphabetical order
158
- prompt_tags = np.sort(prompt_tags)[::-1]
159
 
160
- tag = st.selectbox('Select a tag', prompt_tags)
161
 
162
- items = promptBook[promptBook['tag'] == tag].reset_index(drop=True)
163
 
164
- prompts = np.sort(items['prompt'].unique())[::-1]
165
 
166
- selected_prompt = st.selectbox('Select prompt', prompts)
167
 
168
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
169
- prompt_id = items['prompt_id'].unique()[0]
170
- note = items['note'].unique()[0]
171
 
172
- # show source
173
- if isinstance(note, str):
174
- if note.isdigit():
175
- st.caption(f"`Source: civitai`")
176
- else:
177
- st.caption(f"`Source: {note}`")
178
- else:
179
- st.caption("`Source: Parti-prompts`")
180
-
181
- # show image metadata
182
- image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
183
- for key in image_metadatas:
184
- label = ' '.join(key.split('_')).capitalize()
185
- st.write(f"**{label}**")
186
- if items[key][0] == ' ':
187
- st.write('`None`')
188
- else:
189
- st.caption(f"{items[key][0]}")
190
 
191
- # for note as civitai image id, add civitai reference
192
- if isinstance(note, str) and note.isdigit():
193
- try:
194
- st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
195
- res = requests.get(f'https://civitai.com/images/{note}')
196
- # st.write(res.text)
197
- soup = BeautifulSoup(res.text, 'html.parser')
198
- image_section = soup.find('div', {'class': 'mantine-12rlksp'})
199
- image_url = image_section.find('img')['src']
200
- st.image(image_url, use_column_width=True)
201
- except:
202
- pass
203
 
204
- return prompt_tags, tag, prompt_id, items
205
 
206
- def app(promptBook, images_ds):
207
- st.title('Model Visualization and Retrieval')
208
- st.write('This is a gallery of images generated by the models')
209
 
210
- prompt_tags, tag, prompt_id, items = sidebar(promptBook, images_ds)
211
 
212
- # add safety check for some prompts
213
- safety_check = True
214
- unsafe_prompts = {}
215
- # initialize unsafe prompts
216
- for prompt_tag in prompt_tags:
217
- unsafe_prompts[prompt_tag] = []
218
- # manually add unsafe prompts
219
- unsafe_prompts['world knowledge'] = [83]
220
- # unsafe_prompts['art'] = [23]
221
- unsafe_prompts['abstract'] = [1, 3]
222
- # unsafe_prompts['food'] = [34]
223
 
224
- if int(prompt_id.item()) in unsafe_prompts[tag]:
225
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
226
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
227
 
228
- if safety_check:
229
- items, info, col_num = selection_panel(items)
230
 
231
- if 'selected_dict' in st.session_state:
232
- st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
233
- dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
234
- dynamic_weight_panel = st.columns(len(dynamic_weight_options))
235
 
236
- if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
237
- btn_disable = False
238
  else:
239
- btn_disable = True
240
-
241
- for i in range(len(dynamic_weight_options)):
242
- method = dynamic_weight_options[i]
243
- with dynamic_weight_panel[i]:
244
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=dynamic_weight, args=(prompt_id, items, method))
245
-
246
- with st.form(key=f'{prompt_id}'):
247
-
248
- buttons_space = st.columns([1, 1, 1, 1])
249
- # gallery_space = st.empty()
250
- #
251
- with buttons_space[0]:
252
- continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
253
- if continue_btn:
254
- submit_actions('Continue', prompt_id)
255
-
256
- with buttons_space[1]:
257
- select_btn = st.form_submit_button('Select All', use_container_width=True)
258
- if select_btn:
259
- submit_actions('Select', prompt_id)
260
-
261
- with buttons_space[2]:
262
- deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
263
- if deselect_btn:
264
- submit_actions('Deselect', prompt_id)
265
-
266
- # with buttons_space[3]:
267
- # refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
268
- #
269
- # with gallery_space.container():
270
- # with st.spinner('Loading images...'):
271
- # gallery_standard(items, col_num, info)
272
-
273
- for i in range(100):
274
- st.write('placeholder')
275
-
276
- def submit_actions(status, prompt_id):
277
- if status == 'Select':
278
- modelVersions = promptBook[promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
279
- st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
280
- print(st.session_state.selected_dict, 'select')
281
- st.experimental_rerun()
282
- elif status == 'Deselect':
283
- st.session_state.selected_dict[prompt_id] = []
284
- print(st.session_state.selected_dict, 'deselect')
285
- st.experimental_rerun()
286
- # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
287
- elif status == 'Continue':
288
- st.session_state.selected_dict[prompt_id] = []
289
- for key in st.session_state:
290
- keys = key.split('_')
291
- if keys[0] == 'select' and keys[1] == str(prompt_id):
292
- if st.session_state[key]:
293
- st.session_state.selected_dict[prompt_id].append(int(keys[2]))
294
- # switch_page("ranking")
295
- print(st.session_state.selected_dict, 'continue')
296
- st.experimental_rerun()
297
-
298
- def dynamic_weight(prompt_id, items, method='Grid Search'):
299
- selected = items[
300
- items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
301
- optimal_weight = [0, 0, 0]
302
-
303
- if method == 'Grid Search':
304
- # grid search method
305
- top_ranking = len(items) * len(selected)
306
-
307
- for clip_weight in np.arange(-1, 1, 0.1):
308
- for mcos_weight in np.arange(-1, 1, 0.1):
309
- for pop_weight in np.arange(-1, 1, 0.1):
310
-
311
- weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
312
- weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
313
- # print('weight_all_sorted:', weight_all_sorted)
314
- weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
315
-
316
- # get the index of values of weight_selected in weight_all_sorted
317
- rankings = []
318
- for weight in weight_selected:
319
- rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
320
- if sum(rankings) <= top_ranking:
321
- top_ranking = sum(rankings)
322
- print('current top ranking:', top_ranking, rankings)
323
- optimal_weight = [clip_weight, mcos_weight, pop_weight]
324
- print('optimal weight:', optimal_weight)
325
-
326
- elif method == 'SVM':
327
- # svm method
328
- print('start svm method')
329
- # get residual dataframe that contains models not selected
330
- residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
331
- residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
332
- residual = residual.to_numpy()
333
- selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
334
- selected = selected.to_numpy()
335
-
336
- y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
337
- X = np.concatenate((selected, residual), axis=0)
338
-
339
- # fit svm model, and get parameters for the hyperplane
340
- clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
341
- clf.fit(X, y)
342
- optimal_weight = clf.coef_[0].tolist()
343
- print('optimal weight:', optimal_weight)
344
- pass
345
-
346
- elif method == 'Greedy':
347
- for idx in selected.index:
348
- # find which score is the highest, clip, mcos, or pop
349
- clip_score = selected.loc[idx, 'norm_clip_crop']
350
- mcos_score = selected.loc[idx, 'norm_mcos_crop']
351
- pop_score = selected.loc[idx, 'norm_pop']
352
- if clip_score >= mcos_score and clip_score >= pop_score:
353
- optimal_weight[0] += 1
354
- elif mcos_score >= clip_score and mcos_score >= pop_score:
355
- optimal_weight[1] += 1
356
- elif pop_score >= clip_score and pop_score >= mcos_score:
357
- optimal_weight[2] += 1
358
-
359
- # normalize optimal_weight
360
- optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
361
- print('optimal weight:', optimal_weight)
362
-
363
- st.session_state.score_weights[0: 3] = optimal_weight
364
 
365
 
366
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
@@ -439,4 +444,5 @@ if __name__ == "__main__":
439
  if 'selected_dict' not in st.session_state:
440
  st.session_state['selected_dict'] = {}
441
 
442
- app(promptBook, images_ds)
17
 
18
+ class GalleryApp:
19
+ def __init__(self, promptBook, images_ds):
20
+ self.promptBook = promptBook
21
+ self.images_ds = images_ds
22
+
23
+ def gallery_standard(self, items, col_num, info):
24
+ rows = len(items) // col_num + 1
25
+ containers = [st.container() for _ in range(rows)]
26
+ for idx in range(0, len(items), col_num):
27
+ row_idx = idx // col_num
28
+ with containers[row_idx]:
29
+ cols = st.columns(col_num)
30
+ for j in range(col_num):
31
+ if idx + j < len(items):
32
+ with cols[j]:
33
+ # show image
34
+ # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
35
+ # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
36
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
37
+ st.image(image, use_column_width=True)
38
+
39
+ # handle checkbox information
40
+ prompt_id = items.iloc[idx + j]['prompt_id']
41
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
42
+
43
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
44
+
45
+ st.write("Position: ", idx + j)
46
+
47
+ # show checkbox
48
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
49
+
50
+ # show selected info
51
+ for key in info:
52
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
53
+
54
+ def selection_panel(self, items):
55
+ # temporary function
56
+
57
+ selecters = st.columns([1, 4])
58
+
59
+ if 'score_weights' not in st.session_state:
60
+ st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
61
+
62
+ # select sort type
63
+ with selecters[0]:
64
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
65
+ if sort_type == 'Scores':
66
+ sort_by = 'weighted_score_sum'
67
+
68
+ # select other options
69
+ with selecters[1]:
70
+ if sort_type == 'IDs and Names':
71
+ sub_selecters = st.columns([3, 1])
72
+ # select sort by
73
+ with sub_selecters[0]:
74
+ sort_by = st.selectbox('Sort by',
75
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
76
+ label_visibility='hidden')
77
+
78
+ continue_idx = 1
79
 
80
+ else:
81
+ # add custom weights
82
+ sub_selecters = st.columns([1, 1, 1, 1])
83
 
84
+ with sub_selecters[0]:
85
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
86
+ with sub_selecters[1]:
87
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(ean) s(imilarity) q(uantile) score for measuring distinctiveness')
88
+ with sub_selecters[2]:
89
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
 
91
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
+ 'norm_pop'] * pop_weight, 4)
93
 
94
+ continue_idx = 3
 
95
 
96
+ # save latest weights
97
+ st.session_state.score_weights[0] = clip_weight
98
+ st.session_state.score_weights[1] = mcos_weight
99
+ st.session_state.score_weights[2] = pop_weight
100
 
101
+ # select threshold
102
+ with sub_selecters[continue_idx]:
103
+ nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
+ items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
 
106
+ # save latest threshold
107
+ st.session_state.score_weights[3] = nsfw_threshold
108
 
109
+ # draw a distribution histogram
110
+ if sort_type == 'Scores':
111
+ try:
112
+ with st.expander('Show score distribution histogram and select score range'):
113
+ st.write('**Score distribution histogram**')
114
+ chart_space = st.container()
115
+ # st.write('Select the range of scores to show')
116
+ hist_data = pd.DataFrame(items[sort_by])
117
+ mini = hist_data[sort_by].min().item()
118
+ mini = mini//0.1 * 0.1
119
+ maxi = hist_data[sort_by].max().item()
120
+ maxi = maxi//0.1 * 0.1 + 0.1
121
+ st.write('**Select the range of scores to show**')
122
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
123
+ with chart_space:
124
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
125
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
126
+ # r = event_dict.get(sort_by)
127
+ if r:
128
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
129
+ # st.write(r)
130
+ except:
131
+ pass
132
 
133
+ display_options = st.columns([1, 4])
134
 
135
+ with display_options[0]:
136
+ # select order
137
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
138
+ if order == 'Ascending':
139
+ order = True
140
+ else:
141
+ order = False
142
 
143
+ with display_options[1]:
144
 
145
+ # select info to show
146
+ info = st.multiselect('Show Info',
147
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
148
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
149
+ 'nsfw_score', 'norm_nsfw'],
150
+ default=sort_by)
151
 
152
+ # apply sorting to dataframe
153
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
154
 
155
+ # select number of columns
156
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
157
 
158
+ return items, info, col_num
159
 
160
+ def sidebar(self):
161
+ with st.sidebar:
162
+ prompt_tags = self.promptBook['tag'].unique()
163
+ # sort tags by alphabetical order
164
+ prompt_tags = np.sort(prompt_tags)[::-1]
165
 
166
+ tag = st.selectbox('Select a tag', prompt_tags)
167
 
168
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
169
 
170
+ prompts = np.sort(items['prompt'].unique())[::-1]
171
 
172
+ selected_prompt = st.selectbox('Select prompt', prompts)
 
173
 
174
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
175
+ prompt_id = items['prompt_id'].unique()[0]
176
+ note = items['note'].unique()[0]
 
177
 
178
+ # show source
179
+ if isinstance(note, str):
180
+ if note.isdigit():
181
+ st.caption(f"`Source: civitai`")
182
+ else:
183
+ st.caption(f"`Source: {note}`")
184
  else:
185
+ st.caption("`Source: Parti-prompts`")
186
+
187
+ # show image metadata
188
+ image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
189
+ for key in image_metadatas:
190
+ label = ' '.join(key.split('_')).capitalize()
191
+ st.write(f"**{label}**")
192
+ if items[key][0] == ' ':
193
+ st.write('`None`')
194
+ else:
195
+ st.caption(f"{items[key][0]}")
196
+
197
+ # for note as civitai image id, add civitai reference
198
+ if isinstance(note, str) and note.isdigit():
199
+ try:
200
+ st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
201
+ res = requests.get(f'https://civitai.com/images/{note}')
202
+ # st.write(res.text)
203
+ soup = BeautifulSoup(res.text, 'html.parser')
204
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
205
+ image_url = image_section.find('img')['src']
206
+ st.image(image_url, use_column_width=True)
207
+ except:
208
+ pass
209
+
210
+ return prompt_tags, tag, prompt_id, items
211
+
212
+ def app(self):
213
+ st.title('Model Visualization and Retrieval')
214
+ st.write('This is a gallery of images generated by the models')
215
+
216
+ prompt_tags, tag, prompt_id, items = self.sidebar()
217
+
218
+ # add safety check for some prompts
219
+ safety_check = True
220
+ unsafe_prompts = {}
221
+ # initialize unsafe prompts
222
+ for prompt_tag in prompt_tags:
223
+ unsafe_prompts[prompt_tag] = []
224
+ # manually add unsafe prompts
225
+ unsafe_prompts['world knowledge'] = [83]
226
+ # unsafe_prompts['art'] = [23]
227
+ unsafe_prompts['abstract'] = [1, 3]
228
+ # unsafe_prompts['food'] = [34]
229
+
230
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
231
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
232
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
+
234
+ if safety_check:
235
+ items, info, col_num = self.selection_panel(items)
236
+
237
+ if 'selected_dict' in st.session_state:
238
+ # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
239
+ dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
240
+ dynamic_weight_panel = st.columns(len(dynamic_weight_options))
241
+
242
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
243
+ btn_disable = False
244
+ else:
245
+ btn_disable = True
246
+
247
+ for i in range(len(dynamic_weight_options)):
248
+ method = dynamic_weight_options[i]
249
+ with dynamic_weight_panel[i]:
250
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
251
+
252
+ with st.form(key=f'{prompt_id}'):
253
+ # buttons = st.columns([1, 1, 1])
254
+ buttons_space = st.columns([1, 1, 1, 1])
255
+ gallery_space = st.empty()
256
+
257
+ with buttons_space[0]:
258
+ continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
259
+ if continue_btn:
260
+ self.submit_actions('Continue', prompt_id)
261
+
262
+ with buttons_space[1]:
263
+ select_btn = st.form_submit_button('Select All', use_container_width=True)
264
+ if select_btn:
265
+ self.submit_actions('Select', prompt_id)
266
+
267
+ with buttons_space[2]:
268
+ deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
269
+ if deselect_btn:
270
+ self.submit_actions('Deselect', prompt_id)
271
+
272
+ with buttons_space[3]:
273
+ refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
274
+
275
+ with gallery_space.container():
276
+ with st.spinner('Loading images...'):
277
+ self.gallery_standard(items, col_num, info)
278
+
279
+ prompt = st.chat_input(f"checked: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=True, key=f'{prompt_id}')
280
+
281
+ def submit_actions(self, status, prompt_id):
282
+ if status == 'Select':
283
+ modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
284
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
285
+ print(st.session_state.selected_dict, 'select')
286
+ st.experimental_rerun()
287
+ elif status == 'Deselect':
288
+ st.session_state.selected_dict[prompt_id] = []
289
+ print(st.session_state.selected_dict, 'deselect')
290
+ st.experimental_rerun()
291
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
292
+ elif status == 'Continue':
293
+ st.session_state.selected_dict[prompt_id] = []
294
+ for key in st.session_state:
295
+ keys = key.split('_')
296
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
297
+ if st.session_state[key]:
298
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
299
+ # switch_page("ranking")
300
+ print(st.session_state.selected_dict, 'continue')
301
+ st.experimental_rerun()
302
+
303
+ def dynamic_weight(self, prompt_id, items, method='Grid Search'):
304
+ selected = items[
305
+ items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
306
+ optimal_weight = [0, 0, 0]
307
+
308
+ if method == 'Grid Search':
309
+ # grid search method
310
+ top_ranking = len(items) * len(selected)
311
+
312
+ for clip_weight in np.arange(-1, 1, 0.1):
313
+ for mcos_weight in np.arange(-1, 1, 0.1):
314
+ for pop_weight in np.arange(-1, 1, 0.1):
315
+
316
+ weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
317
+ weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
318
+ # print('weight_all_sorted:', weight_all_sorted)
319
+ weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
320
+
321
+ # get the index of values of weight_selected in weight_all_sorted
322
+ rankings = []
323
+ for weight in weight_selected:
324
+ rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
325
+ if sum(rankings) <= top_ranking:
326
+ top_ranking = sum(rankings)
327
+ print('current top ranking:', top_ranking, rankings)
328
+ optimal_weight = [clip_weight, mcos_weight, pop_weight]
329
+ print('optimal weight:', optimal_weight)
330
+
331
+ elif method == 'SVM':
332
+ # svm method
333
+ print('start svm method')
334
+ # get residual dataframe that contains models not selected
335
+ residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
336
+ residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
337
+ residual = residual.to_numpy()
338
+ selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
339
+ selected = selected.to_numpy()
340
+
341
+ y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
342
+ X = np.concatenate((selected, residual), axis=0)
343
+
344
+ # fit svm model, and get parameters for the hyperplane
345
+ clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
346
+ clf.fit(X, y)
347
+ optimal_weight = clf.coef_[0].tolist()
348
+ print('optimal weight:', optimal_weight)
349
+ pass
350
+
351
+ elif method == 'Greedy':
352
+ for idx in selected.index:
353
+ # find which score is the highest, clip, mcos, or pop
354
+ clip_score = selected.loc[idx, 'norm_clip_crop']
355
+ mcos_score = selected.loc[idx, 'norm_mcos_crop']
356
+ pop_score = selected.loc[idx, 'norm_pop']
357
+ if clip_score >= mcos_score and clip_score >= pop_score:
358
+ optimal_weight[0] += 1
359
+ elif mcos_score >= clip_score and mcos_score >= pop_score:
360
+ optimal_weight[1] += 1
361
+ elif pop_score >= clip_score and pop_score >= mcos_score:
362
+ optimal_weight[2] += 1
363
+
364
+ # normalize optimal_weight
365
+ optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
366
+ print('optimal weight:', optimal_weight)
367
+
368
+ st.session_state.score_weights[0: 3] = optimal_weight
369
 
370
 
371
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
 
444
  if 'selected_dict' not in st.session_state:
445
  st.session_state['selected_dict'] = {}
446
 
447
+ app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
448
+ app.app()
pages/__pycache__/Gallery.cpython-39.pyc DELETED
Binary file (12.3 kB)
 
pages/streamlit-1.25.py CHANGED
@@ -62,17 +62,17 @@ if st.button('Three cheers'):
62
  if "chat_messages" not in st.session_state:
63
  st.session_state.chat_messages = []
64
 
65
- # prompt = st.chat_input("Say something")
66
- # if prompt:
67
- # st.session_state.chat_messages.append({"type": "user", "message": prompt})
68
- # st.session_state.chat_messages.append({"type": "bot", "message": "Hello!", "chart": np.random.randn(30, 3)})
69
- #
70
- # for message in st.session_state.chat_messages[::-1]:
71
- # if message["type"] == "user":
72
- # with st.chat_message("You"):
73
- # st.write(message["message"])
74
- # else:
75
- # with st.chat_message("Bot"):
76
- # st.write(message["message"])
77
- # st.line_chart(message["chart"])
78
65
+ prompt = st.chat_input("Say something")
66
+ if prompt:
67
+ st.session_state.chat_messages.append({"type": "user", "message": prompt})
68
+ st.session_state.chat_messages.append({"type": "bot", "message": "Hello!", "chart": np.random.randn(30, 3)})
69
+
70
+ for message in st.session_state.chat_messages[::-1]:
71
+ if message["type"] == "user":
72
+ with st.chat_message("You"):
73
+ st.write(message["message"])
74
+ else:
75
+ with st.chat_message("Bot"):
76
+ st.write(message["message"])
77
+ st.line_chart(message["chart"])
78