Ricercar committed on
Commit fb1a1d0
1 Parent(s): e0e6d97

add new backup

Files changed (1)
  1. Archive/Gallery_beta0920.py +718 -0
Archive/Gallery_beta0920.py ADDED
@@ -0,0 +1,718 @@
1
+ import json
2
+ import os
3
+ import requests
4
+
5
+ import altair as alt
6
+ import extra_streamlit_components as stx
7
+ import numpy as np
8
+ import pandas as pd
9
+ import streamlit as st
10
+ import streamlit.components.v1 as components
11
+
12
+ from bs4 import BeautifulSoup
13
+ from datasets import load_dataset, Dataset, load_from_disk
14
+ from huggingface_hub import login
15
+ from streamlit_agraph import agraph, Node, Edge, Config
16
+ from streamlit_extras.switch_page_button import switch_page
17
+ from streamlit_extras.no_default_selectbox import selectbox
18
+ from sklearn.svm import LinearSVC
19
+
20
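+ # shorthand score keys mapped to the corresponding score columns in the metadata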
+ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
21
+
22
+
23
+ class GalleryApp:
24
+ def __init__(self, promptBook, images_ds):
25
+ self.promptBook = promptBook
26
+ self.images_ds = images_ds
27
+
28
+ # init gallery state
29
+ if 'gallery_state' not in st.session_state:
30
+ st.session_state.gallery_state = {}
31
+
32
+ # initialize selected_dict
33
+ if 'selected_dict' not in st.session_state:
34
+ st.session_state['selected_dict'] = {}
35
+
36
+ if 'gallery_focus' not in st.session_state:
37
+ st.session_state.gallery_focus = {'tag': None, 'prompt': None}
38
+
39
+ def gallery_standard(self, items, col_num, info):
40
+ rows = len(items) // col_num + 1
41
+ containers = [st.container() for _ in range(rows)]
42
+ for idx in range(0, len(items), col_num):
43
+ row_idx = idx // col_num
44
+ with containers[row_idx]:
45
+ cols = st.columns(col_num)
46
+ for j in range(col_num):
47
+ if idx + j < len(items):
48
+ with cols[j]:
49
+ # show image
50
+ # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
51
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
52
+ st.image(image, use_column_width=True)
53
+
54
+ # handle checkbox information
55
+ prompt_id = items.iloc[idx + j]['prompt_id']
56
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
57
+
58
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
59
+
60
+ # st.write("Position: ", idx + j)
61
+
62
+ # show checkbox
63
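+ # the key format 'select_{prompt_id}_{modelVersion_id}' is parsed later in submit_actions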
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
64
+
65
+ # show selected info
66
+ for key in info:
67
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
68
+
69
+ def gallery_graph(self, items):
70
+ items = load_tsne_coordinates(items)
71
+
72
+ # sort items by popularity from low to high, so that the most popular ones are drawn on top
73
+ items = items.sort_values(by=['model_download_count'], ascending=True).reset_index(drop=True)
74
+
75
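+ # spread the t-SNE coordinates apart so the image nodes do not overlap in the graph view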
+ scale = 50
76
+ items.loc[:, 'x'] = items['x'] * scale
77
+ items.loc[:, 'y'] = items['y'] * scale
78
+
79
+ nodes = []
80
+ edges = []
81
+
82
+ for idx in items.index:
83
+ # if items.loc[idx, 'modelVersion_id'] in st.session_state.selected_dict.get(items.loc[idx, 'prompt_id'], 0):
84
+ # opacity = 0.2
85
+ # else:
86
+ # opacity = 1.0
87
+
88
+ nodes.append(Node(id=items.loc[idx, 'image_id'],
89
+ # label=str(items.loc[idx, 'model_name']),
90
+ title=f"model name: {items.loc[idx, 'model_name']}\nmodelVersion name: {items.loc[idx, 'modelVersion_name']}\nclip score: {items.loc[idx, 'clip_score']}\nmcos score: {items.loc[idx, 'mcos_score']}\npopularity: {items.loc[idx, 'model_download_count']}",
91
+ size=20,
92
+ shape='image',
93
+ image=f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.loc[idx, 'image_id']}.png",
94
+ x=items.loc[idx, 'x'].item(),
95
+ y=items.loc[idx, 'y'].item(),
96
+ # fixed=True,
97
+ color={'background': '#E0E0E1', 'border': '#ffffff', 'highlight': {'border': '#F04542'}},
98
+ # opacity=opacity,
99
+ shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
100
+ borderWidth=2,
101
+ shapeProperties={'useBorderWithImage': True},
102
+ )
103
+ )
104
+
105
+ config = Config(width='100%',
106
+ height='600',
107
+ directed=True,
108
+ physics=False,
109
+ hierarchical=False,
110
+ interaction={'navigationButtons': True, 'dragNodes': False, 'multiselect': False},
111
+ # **kwargs
112
+ )
113
+
114
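+ # agraph returns the id of the clicked node (an image_id), which graph_mode uses to show its details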
+ return agraph(nodes=nodes,
115
+ edges=edges,
116
+ config=config,
117
+ )
118
+
119
+ def selection_panel(self, items):
120
+ # temporary function
121
+
122
+ selecters = st.columns([1, 4])
123
+
124
+ if 'score_weights' not in st.session_state:
125
+ # st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
126
+ st.session_state.score_weights = [1.0, 0.8, 0.2]
127
+
128
+ # select sort type
129
+ with selecters[0]:
130
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
131
+ if sort_type == 'Scores':
132
+ sort_by = 'weighted_score_sum'
133
+
134
+ # select other options
135
+ with selecters[1]:
136
+ if sort_type == 'IDs and Names':
137
+ sub_selecters = st.columns([3])
138
+ # select sort by
139
+ with sub_selecters[0]:
140
+ sort_by = st.selectbox('Sort by',
141
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
142
+ label_visibility='hidden')
143
+
144
+ continue_idx = 1
145
+
146
+ else:
147
+ # add custom weights
148
+ sub_selecters = st.columns([1, 1, 1])
149
+
150
+ with sub_selecters[0]:
151
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1, help='the weight for normalized clip score')
152
+ with sub_selecters[1]:
153
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=0.8, step=0.1, help='the weight for the m(ean) s(imilarity) q(uantile) score for measuring distinctiveness')
154
+ with sub_selecters[2]:
155
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=0.2, step=0.1, help='the weight for normalized popularity score')
156
+
157
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
158
+ 'norm_pop'] * pop_weight, 4)
159
+
160
+ continue_idx = 3
161
+
162
+ # save latest weights
163
+ st.session_state.score_weights[0] = round(clip_weight, 2)
164
+ st.session_state.score_weights[1] = round(mcos_weight, 2)
165
+ st.session_state.score_weights[2] = round(pop_weight, 2)
166
+
167
+ # # select threshold
168
+ # with sub_selecters[continue_idx]:
169
+ # nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=0.8, step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
170
+ # items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
171
+ #
172
+ # # save latest threshold
173
+ # st.session_state.score_weights[3] = nsfw_threshold
174
+
175
+ # # draw a distribution histogram
176
+ # if sort_type == 'Scores':
177
+ # try:
178
+ # with st.expander('Show score distribution histogram and select score range'):
179
+ # st.write('**Score distribution histogram**')
180
+ # chart_space = st.container()
181
+ # # st.write('Select the range of scores to show')
182
+ # hist_data = pd.DataFrame(items[sort_by])
183
+ # mini = hist_data[sort_by].min().item()
184
+ # mini = mini//0.1 * 0.1
185
+ # maxi = hist_data[sort_by].max().item()
186
+ # maxi = maxi//0.1 * 0.1 + 0.1
187
+ # st.write('**Select the range of scores to show**')
188
+ # r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
189
+ # with chart_space:
190
+ # st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
191
+ # # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
192
+ # # r = event_dict.get(sort_by)
193
+ # if r:
194
+ # items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
195
+ # # st.write(r)
196
+ # except:
197
+ # pass
198
+
199
+ display_options = st.columns([1, 4])
200
+
201
+ with display_options[0]:
202
+ # select order
203
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
204
+ if order == 'Ascending':
205
+ order = True
206
+ else:
207
+ order = False
208
+
209
+ with display_options[1]:
210
+
211
+ # select info to show
212
+ info = st.multiselect('Show Info',
213
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
214
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
215
+ 'nsfw_score', 'norm_nsfw'],
216
+ default=sort_by)
217
+
218
+ # apply sorting to dataframe
219
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
220
+
221
+ # select number of columns
222
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
223
+
224
+ return items, info, col_num
225
+
226
+ def sidebar(self, items, prompt_id, note):
227
+ with st.sidebar:
228
+ # prompt_tags = self.promptBook['tag'].unique()
229
+ # # sort tags by alphabetical order
230
+ # prompt_tags = np.sort(prompt_tags)[::1]
231
+ #
232
+ # tag = st.selectbox('Select a tag', prompt_tags, index=5)
233
+ #
234
+ # items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
235
+ #
236
+ # prompts = np.sort(items['prompt'].unique())[::1]
237
+ #
238
+ # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
239
+
240
+ # mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)
241
+
242
+ # items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
243
+
244
+ # st.title('Model Visualization and Retrieval')
245
+
246
+ # show source
247
+ if isinstance(note, str):
248
+ if note.isdigit():
249
+ st.caption(f"`Source: civitai`")
250
+ else:
251
+ st.caption(f"`Source: {note}`")
252
+ else:
253
+ st.caption("`Source: Parti-prompts`")
254
+
255
+ # show image metadata
256
+ image_metadatas = ['prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
257
+ for key in image_metadatas:
258
+ label = ' '.join(key.split('_')).capitalize()
259
+ st.write(f"**{label}**")
260
+ if items[key][0] == ' ':
261
+ st.write('`None`')
262
+ else:
263
+ st.caption(f"{items[key][0]}")
264
+
265
+ # when the note is a civitai image id, add a civitai reference
266
+ if isinstance(note, str) and note.isdigit():
267
+ try:
268
+ st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
269
+ res = requests.get(f'https://civitai.com/images/{note}')
270
+ # st.write(res.text)
271
+ soup = BeautifulSoup(res.text, 'html.parser')
272
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
273
+ image_url = image_section.find('img')['src']
274
+ st.image(image_url, use_column_width=True)
275
+ except:
276
+ pass
277
+
278
+ # return prompt_tags, tag, prompt_id, items
279
+
280
+ def app(self):
281
+ st.write('### Model Visualization and Retrieval')
282
+ # st.write('This is a gallery of images generated by the models')
283
+
284
+ # build the tabular view
285
+ prompt_tags = self.promptBook['tag'].unique()
286
+ # sort tags by alphabetical order
287
+ prompt_tags = np.sort(prompt_tags)[::1].tolist()
288
+
289
+ # chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
290
+ # tag = stx.tab_bar(chosen_data, key='tag', default='food')
291
+
292
+ # save tag to session state on change
293
+ tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag', label_visibility='collapsed')
294
+
295
+ # tabs = st.tabs(prompt_tags)
296
+ # for i in range(len(prompt_tags)):
297
+ # with tabs[i]:
298
+ # tag = prompt_tags[i]
299
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
300
+
301
+ prompts = np.sort(items['prompt'].unique())[::1].tolist()
302
+
303
+ # st.caption('Select a prompt')
304
+ subset_selector = st.columns([3, 1])
305
+ with subset_selector[0]:
306
+ selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=0)
307
+ # st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
308
+
309
+ if selected_prompt is None:
310
+ # st.markdown(':orange[Please select a prompt above👆]')
311
+ st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
312
+
313
+ with subset_selector[-1]:
314
+ st.write(':orange[👈 **Please select a prompt**]')
315
+
316
+ else:
317
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
318
+ prompt_id = items['prompt_id'].unique()[0]
319
+ note = items['note'].unique()[0]
320
+
321
+ # add state to session state
322
+ if prompt_id not in st.session_state.gallery_state:
323
+ st.session_state.gallery_state[prompt_id] = 'graph'
324
+
325
+ # add focus to session state
326
+ st.session_state.gallery_focus['tag'] = tag
327
+ st.session_state.gallery_focus['prompt'] = selected_prompt
328
+
329
+ # add safety check for some prompts
330
+ safety_check = True
331
+
332
+ # load unsafe prompts
333
+ unsafe_prompts = json.load(open('./data/unsafe_prompts.json', 'r'))
334
+ for prompt_tag in prompt_tags:
335
+ if prompt_tag not in unsafe_prompts:
336
+ unsafe_prompts[prompt_tag] = []
337
+ # # manually add unsafe prompts
338
+ # unsafe_prompts['world knowledge'] = [83]
339
+ # unsafe_prompts['abstract'] = [1, 3]
340
+
341
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
342
+ st.warning('This prompt may contain unsafe content. The generated images might be offensive, depressing, or sexual.')
343
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
344
+
345
+ print('current state: ', st.session_state.gallery_state[prompt_id])
346
+
347
+ if st.session_state.gallery_state[prompt_id] == 'graph':
348
+ if safety_check:
349
+ self.graph_mode(prompt_id, items)
350
+ with subset_selector[-1]:
351
+ has_selection = False
352
+ try:
353
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
354
+ has_selection = True
355
+ except:
356
+ pass
357
+
358
+ if has_selection:
359
+ checkout = st.button('Check out selections', use_container_width=True, type='primary')
360
+ if checkout:
361
+ print('checkout')
362
+
363
+ st.session_state.gallery_state[prompt_id] = 'gallery'
364
+ print(st.session_state.gallery_state[prompt_id])
365
+ st.experimental_rerun()
366
+ else:
367
+ st.write(':orange[👇 **Select images you like below**]')
368
+
369
+ elif st.session_state.gallery_state[prompt_id] == 'gallery':
370
+ items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
371
+ drop=True)
372
+ self.gallery_mode(prompt_id, items)
373
+
374
+ with subset_selector[-1]:
375
+ state_operations = st.columns([1, 1])
376
+ with state_operations[0]:
377
+ back = st.button('Back to 🖼️', use_container_width=True)
378
+ if back:
379
+ st.session_state.gallery_state[prompt_id] = 'graph'
380
+ st.experimental_rerun()
381
+
382
+ with state_operations[1]:
383
+ forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
384
+ if forward:
385
+ switch_page('ranking')
386
+
387
+ try:
388
+ self.sidebar(items, prompt_id, note)
389
+ except:
390
+ pass
391
+
392
+ def graph_mode(self, prompt_id, items):
393
+ graph_cols = st.columns([3, 1])
394
+ # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
395
+ # disabled=False, key=f'{prompt_id}')
396
+ # if prompt:
397
+ # switch_page("ranking")
398
+
399
+ with graph_cols[0]:
400
+ graph_space = st.empty()
401
+
402
+ with graph_space.container():
403
+ return_value = self.gallery_graph(items)
404
+
405
+ with graph_cols[1]:
406
+ if return_value:
407
+ with st.form(key=f'{prompt_id}'):
408
+ image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{return_value}.png"
409
+
410
+ st.image(image_url)
411
+
412
+ item = items[items['image_id'] == return_value].reset_index(drop=True).iloc[0]
413
+ modelVersion_id = item['modelVersion_id']
414
+
415
+ # handle selection
416
+ if 'selected_dict' in st.session_state:
417
+ if item['prompt_id'] not in st.session_state.selected_dict:
418
+ st.session_state.selected_dict[item['prompt_id']] = []
419
+
420
+ if modelVersion_id in st.session_state.selected_dict[item['prompt_id']]:
421
+ checked = True
422
+ else:
423
+ checked = False
424
+
425
+ if checked:
426
+ # deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
427
+ deselect = st.form_submit_button('Deselect', use_container_width=True)
428
+ if deselect:
429
+ st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
430
+ self.remove_ranking_states(item['prompt_id'])
431
+ st.experimental_rerun()
432
+
433
+ else:
434
+ # select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
435
+ select = st.form_submit_button('Select', use_container_width=True, type='primary')
436
+ if select:
437
+ st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
438
+ self.remove_ranking_states(item['prompt_id'])
439
+ st.experimental_rerun()
440
+
441
+ # st.write(item)
442
+ infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
443
+ 'nsfw_score']
444
+
445
+ infos_df = item[infos]
446
+ # rename columns
447
+ infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
448
+ st.table(infos_df)
449
+
450
+ # for info in infos:
451
+ # st.write(f"**{info}**:")
452
+ # st.write(item[info])
453
+
454
+ else:
455
+ st.info('Please click on an image to show its details')
456
+
457
+ def gallery_mode(self, prompt_id, items):
458
+ items, info, col_num = self.selection_panel(items)
459
+
460
+ # if 'selected_dict' in st.session_state:
461
+ # # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
462
+ # dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
463
+ # dynamic_weight_panel = st.columns(len(dynamic_weight_options))
464
+ #
465
+ # if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
466
+ # btn_disable = False
467
+ # else:
468
+ # btn_disable = True
469
+ #
470
+ # for i in range(len(dynamic_weight_options)):
471
+ # method = dynamic_weight_options[i]
472
+ # with dynamic_weight_panel[i]:
473
+ # btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
474
+
475
+ # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
476
+ # if prompt:
477
+ # switch_page("ranking")
478
+
479
+ # with st.form(key=f'{prompt_id}'):
480
+ # buttons = st.columns([1, 1, 1])
481
+ # buttons_space = st.columns([1, 1, 1])
482
+ gallery_space = st.empty()
483
+
484
+ # with buttons_space[0]:
485
+ # continue_btn = st.button('Proceed selections to ranking', use_container_width=True, type='primary')
486
+ # if continue_btn:
487
+ # # self.submit_actions('Continue', prompt_id)
488
+ # switch_page("ranking")
489
+ #
490
+ # with buttons_space[1]:
491
+ # deselect_btn = st.button('Deselect All', use_container_width=True)
492
+ # if deselect_btn:
493
+ # self.submit_actions('Deselect', prompt_id)
494
+ #
495
+ # with buttons_space[2]:
496
+ # refresh_btn = st.button('Refresh', on_click=gallery_space.empty, use_container_width=True)
497
+
498
+ with gallery_space.container():
499
+ self.gallery_standard(items, col_num, info)
500
+
501
+ def submit_actions(self, status, prompt_id):
502
+ # remove counter from session state
503
+ # st.session_state.pop('counter', None)
504
+ self.remove_ranking_states(prompt_id)
505
+ if status == 'Select':
506
+ modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
507
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
508
+ print(st.session_state.selected_dict, 'select')
509
+ st.experimental_rerun()
510
+ elif status == 'Deselect':
511
+ st.session_state.selected_dict[prompt_id] = []
512
+ print(st.session_state.selected_dict, 'deselect')
513
+ st.experimental_rerun()
514
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
515
+ elif status == 'Continue':
516
+ st.session_state.selected_dict[prompt_id] = []
517
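+ # rebuild the selection from every checked 'select_{prompt_id}_{modelVersion_id}' checkbox in session state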
+ for key in st.session_state:
518
+ keys = key.split('_')
519
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
520
+ if st.session_state[key]:
521
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
522
+ # switch_page("ranking")
523
+ print(st.session_state.selected_dict, 'continue')
524
+ # st.experimental_rerun()
525
+
526
+ def dynamic_weight(self, prompt_id, items, method='Grid Search'):
527
+ selected = items[
528
+ items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
529
+ optimal_weight = [0, 0, 0]
530
+
531
+ if method == 'Grid Search':
532
+ # grid search method
533
+ top_ranking = len(items) * len(selected)
534
+
535
+ for clip_weight in np.arange(-1, 1, 0.1):
536
+ for mcos_weight in np.arange(-1, 1, 0.1):
537
+ for pop_weight in np.arange(-1, 1, 0.1):
538
+
539
+ weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
540
+ weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
541
+ # print('weight_all_sorted:', weight_all_sorted)
542
+ weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
543
+
544
+ # get the index of values of weight_selected in weight_all_sorted
545
+ rankings = []
546
+ for weight in weight_selected:
547
+ rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
548
+ if sum(rankings) <= top_ranking:
549
+ top_ranking = sum(rankings)
550
+ print('current top ranking:', top_ranking, rankings)
551
+ optimal_weight = [clip_weight, mcos_weight, pop_weight]
552
+ print('optimal weight:', optimal_weight)
553
+
554
+ elif method == 'SVM':
555
+ # svm method
556
+ print('start svm method')
557
+ # get residual dataframe that contains models not selected
558
+ residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
559
+ residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
560
+ residual = residual.to_numpy()
561
+ selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
562
+ selected = selected.to_numpy()
563
+
564
+ y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
565
+ X = np.concatenate((selected, residual), axis=0)
566
+
567
+ # fit svm model, and get parameters for the hyperplane
568
+ clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
569
+ clf.fit(X, y)
570
+ optimal_weight = clf.coef_[0].tolist()
571
+ print('optimal weight:', optimal_weight)
572
+ pass
573
+
574
+ elif method == 'Greedy':
575
+ for idx in selected.index:
576
+ # find which score is the highest, clip, mcos, or pop
577
+ clip_score = selected.loc[idx, 'norm_clip_crop']
578
+ mcos_score = selected.loc[idx, 'norm_mcos_crop']
579
+ pop_score = selected.loc[idx, 'norm_pop']
580
+ if clip_score >= mcos_score and clip_score >= pop_score:
581
+ optimal_weight[0] += 1
582
+ elif mcos_score >= clip_score and mcos_score >= pop_score:
583
+ optimal_weight[1] += 1
584
+ elif pop_score >= clip_score and pop_score >= mcos_score:
585
+ optimal_weight[2] += 1
586
+
587
+ # normalize optimal_weight
588
+ optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
589
+ print('optimal weight:', optimal_weight)
590
+ print('optimal weight:', optimal_weight)
591
+
592
+ st.session_state.score_weights[0: 3] = optimal_weight
593
+
594
+
595
+ def remove_ranking_states(self, prompt_id):
596
+ # for drag sort
597
+ try:
598
+ st.session_state.counter[prompt_id] = 0
599
+ st.session_state.ranking[prompt_id] = {}
600
+ print('remove ranking states')
601
+ except:
602
+ print('no sort ranking states to remove')
603
+
604
+ # for battles
605
+ try:
606
+ st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
607
+ print('remove battles states')
608
+ except:
609
+ print('no battles states to remove')
610
+
611
+ # for page progress
612
+ try:
613
+ st.session_state.progress[prompt_id] = 'ranking'
614
+ print('reset page progress states')
615
+ except:
616
+ print('no page progress states to be reset')
617
+
618
+
619
+ # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
620
+ @st.cache_resource
621
+ def altair_histogram(hist_data, sort_by, mini, maxi):
622
+ brushed = alt.selection_interval(encodings=['x'], name="brushed")
623
+
624
+ chart = (
625
+ alt.Chart(hist_data)
626
+ .mark_bar(opacity=0.7, cornerRadius=2)
627
+ .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
628
+ # .add_selection(brushed)
629
+ # .properties(width=800, height=300)
630
+ )
631
+
632
+ # Create a transparent rectangle for highlighting the range
633
+ highlight = (
634
+ alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
635
+ .mark_rect(opacity=0.3)
636
+ .encode(x='x1', x2='x2')
637
+ # .properties(width=800, height=300)
638
+ )
639
+
640
+ # Layer the chart and the highlight rectangle
641
+ layered_chart = alt.layer(chart, highlight)
642
+
643
+ return layered_chart
644
+
645
+
646
+ @st.cache_data
647
+ def load_hf_dataset(show_NSFW=False):
648
+ # login to huggingface
649
+ login(token=os.environ.get("HF_TOKEN"))
650
+
651
+ # load from huggingface
652
+ roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
653
+ promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))
654
+ # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
655
+ images_ds = None # set to None for now since we use s3 bucket to store images
656
+
657
+ # # process dataset
658
+ # roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
659
+ # 'model_download_count']].drop_duplicates().reset_index(drop=True)
660
+
661
+ # add a 'weighted_score_sum' column to promptBook if it does not exist
662
+ if 'weighted_score_sum' not in promptBook.columns:
663
+ promptBook.loc[:, 'weighted_score_sum'] = 0
664
+
665
+ # merge roster and promptbook
666
+ promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
667
+ on=['model_id', 'modelVersion_id'], how='left')
668
+
669
+ # add column to record current row index
670
+ promptBook.loc[:, 'row_idx'] = promptBook.index
671
+
672
+ # apply a nsfw filter
673
+ if not show_NSFW:
674
+ promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)
675
+ print('nsfw filter applied', len(promptBook))
676
+
677
+ # add a column that adds up 'norm_clip', 'norm_mcos', and 'norm_pop'
678
+ score_weights = [1.0, 0.8, 0.2]
679
+ promptBook.loc[:, 'total_score'] = round(promptBook['norm_clip'] * score_weights[0] + promptBook['norm_mcos'] * score_weights[1] + promptBook['norm_pop'] * score_weights[2], 4)
680
+
681
+ return roster, promptBook, images_ds
682
+
683
+ @st.cache_data
684
+ def load_tsne_coordinates(items):
685
+ # load tsne coordinates
686
+ tsne_df = pd.read_parquet('./data/feats_tsne.parquet')
687
+
688
+ # print(tsne_df['modelVersion_id'].dtype)
689
+
690
+ # print('before merge:', items)
691
+ items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
692
+ # print('after merge:', items)
693
+ return items
694
+
695
+
696
+ if __name__ == "__main__":
697
+ st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
698
+
699
+ if 'user_id' not in st.session_state:
700
+ st.warning('Please log in first.')
701
+ home_btn = st.button('Go to Home Page')
702
+ if home_btn:
703
+ switch_page("home")
704
+ else:
705
+ # st.write('You have already logged in as ' + st.session_state.user_id[0])
706
+ roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
707
+ # print(promptBook.columns)
708
+
709
+ # # initialize selected_dict
710
+ # if 'selected_dict' not in st.session_state:
711
+ # st.session_state['selected_dict'] = {}
712
+
713
+ app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
714
+ app.app()
715
+
716
+ with open('./css/style.css') as f:
717
+ st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
718
+