Ricercar commited on
Commit
bca2bcb
โ€ข
1 Parent(s): 2c3dcf3

new version! multiple pages!

Browse files
app.py โ†’ Archive/app.py RENAMED
@@ -100,10 +100,10 @@ class GalleryApp:
100
 
101
  st.image(image, use_column_width=True)
102
 
103
- # # show checkbox
104
- # self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'] = st.checkbox(
105
- # 'Select', value=self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'],
106
- # key=f'select_{idx + j}')
107
 
108
  # st.write(idx+j)
109
  # show selected info
@@ -294,8 +294,8 @@ class GalleryApp:
294
  return items, info, col_num
295
 
296
  def app(self):
297
- st.title('Model Coffer Gallery')
298
- st.write('This is a gallery of images generated by the models in the Model Coffer')
299
 
300
  with st.sidebar:
301
  prompt_tags = self.promptBook['tag'].unique()
@@ -367,7 +367,7 @@ class GalleryApp:
367
 
368
  with st.form(key=f'{prompt_id}', clear_on_submit=True):
369
  # buttons = st.columns([1, 1, 1])
370
- buttons_space = st.container()
371
  gallery_space = st.empty()
372
  # with buttons[0]:
373
  # submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
@@ -379,8 +379,17 @@ class GalleryApp:
379
  with gallery_space.container():
380
  self.gallery_standard(items, col_num, info)
381
 
382
- with buttons_space:
383
- st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True, type='primary')
 
 
 
 
 
 
 
 
 
384
 
385
 
386
  def reset_current_prompt(self, prompt_id):
@@ -416,7 +425,7 @@ def load_hf_dataset():
416
  # load from huggingface
417
  roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
418
  promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
419
- images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
420
 
421
  # process dataset
422
  roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
 
100
 
101
  st.image(image, use_column_width=True)
102
 
103
+ # show checkbox
104
+ self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'] = st.checkbox(
105
+ 'Select', value=self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'],
106
+ key=f'select_{idx + j}')
107
 
108
  # st.write(idx+j)
109
  # show selected info
 
294
  return items, info, col_num
295
 
296
  def app(self):
297
+ st.title('Model Visualization and Retrieval')
298
+ st.write('This is a gallery of images generated by the models')
299
 
300
  with st.sidebar:
301
  prompt_tags = self.promptBook['tag'].unique()
 
367
 
368
  with st.form(key=f'{prompt_id}', clear_on_submit=True):
369
  # buttons = st.columns([1, 1, 1])
370
+ buttons_space = st.columns([1, 1, 1, 1])
371
  gallery_space = st.empty()
372
  # with buttons[0]:
373
  # submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
 
379
  with gallery_space.container():
380
  self.gallery_standard(items, col_num, info)
381
 
382
+ with buttons_space[0]:
383
+ st.form_submit_button('Confirm and Continue', use_container_width=True, type='primary')
384
+
385
+ with buttons_space[1]:
386
+ st.form_submit_button('Select All', use_container_width=True)
387
+
388
+ with buttons_space[2]:
389
+ st.form_submit_button('Deselect All', use_container_width=True)
390
+
391
+ with buttons_space[3]:
392
+ st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
393
 
394
 
395
  def reset_current_prompt(self, prompt_id):
 
425
  # load from huggingface
426
  roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
427
  promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
428
+ images_ds = load_from_disk(os.path.join(os.getcwd(), '../data', 'promptbook'))
429
 
430
  # process dataset
431
  roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
Archive/test.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ if __name__ == "__main__":
4
+ if 'check_dict' not in st.session_state:
5
+ st.session_state.check_dict = {'check1': False, 'check2': False, 'check3': False}
6
+
7
+ with st.form('my_form'):
8
+ st.session_state.check_dict['check1'] = st.checkbox('Check 1 out')
9
+ st.session_state.check_dict['check2'] = st.checkbox('Check 2 out')
10
+ st.session_state.check_dict['check3'] = st.checkbox('Check 3 out')
11
+
12
+ check21 = st.checkbox('Check 21 out')
13
+ if check21:
14
+ st.write('check21 is checked')
15
+ check22 = st.checkbox('Check 22 out')
16
+ if check22:
17
+ st.write('check22 is checked')
18
+ check23 = st.checkbox('Check 23 out')
19
+ if check23:
20
+ st.write('check23 is checked')
21
+
22
+ # Every form must have a submit button.
23
+ submitted = st.form_submit_button('Submit')
24
+
25
+ for key, value in st.session_state.check_dict.items():
26
+ st.write(key, value)
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.19.0
8
  python_version: 3.9.13
9
- app_file: app.py
10
  pinned: false
11
  ---
12
 
 
6
  sdk: streamlit
7
  sdk_version: 1.19.0
8
  python_version: 3.9.13
9
+ app_file: ๐Ÿ _Home.py
10
  pinned: false
11
  ---
12
 
data/download_script.py CHANGED
@@ -20,5 +20,11 @@ def test():
20
  print(promptbook[0]['image'])
21
 
22
 
 
 
 
 
 
 
23
  if __name__ == '__main__':
24
- main()
 
20
  print(promptbook[0]['image'])
21
 
22
 
23
+ # def drop_metadata_checked_column():
24
+ # ModelCofferMetadata = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
25
+ # ModelCofferMetadata = ModelCofferMetadata.remove_columns(['checked'])
26
+ # ModelCofferMetadata.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
27
+
28
+
29
  if __name__ == '__main__':
30
+ main()
pages/1_๐Ÿ–ผ๏ธ_Gallery.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import glob
5
+ from datasets import load_dataset, Dataset, load_from_disk
6
+ from huggingface_hub import login
7
+ import os
8
+ import requests
9
+ from bs4 import BeautifulSoup
10
+ import altair as alt
11
+ from streamlit_extras.switch_page_button import switch_page
12
+
13
+ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
14
+
15
+
16
+ # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
17
+ @st.cache_resource
18
+ def altair_histogram(hist_data, sort_by, mini, maxi):
19
+ brushed = alt.selection_interval(encodings=['x'], name="brushed")
20
+
21
+ chart = (
22
+ alt.Chart(hist_data)
23
+ .mark_bar(opacity=0.7, cornerRadius=2)
24
+ .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
25
+ # .add_selection(brushed)
26
+ # .properties(width=800, height=300)
27
+ )
28
+
29
+ # Create a transparent rectangle for highlighting the range
30
+ highlight = (
31
+ alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
32
+ .mark_rect(opacity=0.3)
33
+ .encode(x='x1', x2='x2')
34
+ # .properties(width=800, height=300)
35
+ )
36
+
37
+ # Layer the chart and the highlight rectangle
38
+ layered_chart = alt.layer(chart, highlight)
39
+
40
+ return layered_chart
41
+
42
+
43
+ class GalleryApp:
44
+ def __init__(self, promptBook, images_ds):
45
+ self.promptBook = promptBook
46
+ self.images_ds = images_ds
47
+
48
+ def gallery_standard(self, items, col_num, info):
49
+ rows = len(items) // col_num + 1
50
+ containers = [st.container() for _ in range(rows)]
51
+ for idx in range(0, len(items), col_num):
52
+ row_idx = idx // col_num
53
+ with containers[row_idx]:
54
+ cols = st.columns(col_num)
55
+ for j in range(col_num):
56
+ if idx + j < len(items):
57
+ with cols[j]:
58
+ # show image
59
+ image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
60
+ st.image(image, use_column_width=True)
61
+
62
+ # handel checkbox information
63
+ prompt_id = items.iloc[idx + j]['prompt_id']
64
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
65
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
66
+
67
+ # show checkbox
68
+ checked = st.checkbox('Select', key=f'select_{idx + j}', value=check_init)
69
+ if checked:
70
+ st.session_state.selected_dict[prompt_id] = st.session_state.selected_dict.get(prompt_id, []) + [modelVersion_id]
71
+ else:
72
+ try:
73
+ st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
74
+ except:
75
+ pass
76
+
77
+ # show selected info
78
+ for key in info:
79
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
80
+
81
+ def selection_panel(self, items):
82
+ selecters = st.columns([1, 4])
83
+
84
+ # select sort type
85
+ with selecters[0]:
86
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
87
+ if sort_type == 'Scores':
88
+ sort_by = 'weighted_score_sum'
89
+
90
+ # select other options
91
+ with selecters[1]:
92
+ if sort_type == 'IDs and Names':
93
+ sub_selecters = st.columns([3, 1])
94
+ # select sort by
95
+ with sub_selecters[0]:
96
+ sort_by = st.selectbox('Sort by',
97
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
98
+ label_visibility='hidden')
99
+
100
+ continue_idx = 1
101
+
102
+ else:
103
+ # add custom weights
104
+ sub_selecters = st.columns([1, 1, 1, 1])
105
+
106
+ if 'score_weights' not in st.session_state:
107
+ st.session_state.score_weights = [1.0, 0.8, 0.2, 0.84]
108
+
109
+ with sub_selecters[0]:
110
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
111
+ with sub_selecters[1]:
112
+ rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for average rank')
113
+ with sub_selecters[2]:
114
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
115
+
116
+ items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
117
+ 'norm_pop'] * pop_weight, 4)
118
+
119
+ continue_idx = 3
120
+
121
+ # select threshold
122
+ with sub_selecters[continue_idx]:
123
+ dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
124
+ items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
125
+
126
+ # save latest weights
127
+ st.session_state.score_weights = [clip_weight, rank_weight, pop_weight, dist_threshold]
128
+
129
+ # draw a distribution histogram
130
+ if sort_type == 'Scores':
131
+ try:
132
+ with st.expander('Show score distribution histogram and select score range'):
133
+ st.write('**Score distribution histogram**')
134
+ chart_space = st.container()
135
+ # st.write('Select the range of scores to show')
136
+ hist_data = pd.DataFrame(items[sort_by])
137
+ mini = hist_data[sort_by].min().item()
138
+ mini = mini//0.1 * 0.1
139
+ maxi = hist_data[sort_by].max().item()
140
+ maxi = maxi//0.1 * 0.1 + 0.1
141
+ st.write('**Select the range of scores to show**')
142
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
143
+ with chart_space:
144
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
145
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
146
+ # r = event_dict.get(sort_by)
147
+ if r:
148
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
149
+ # st.write(r)
150
+ except:
151
+ pass
152
+
153
+ display_options = st.columns([1, 4])
154
+
155
+ with display_options[0]:
156
+ # select order
157
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
158
+ if order == 'Ascending':
159
+ order = True
160
+ else:
161
+ order = False
162
+
163
+ with display_options[1]:
164
+
165
+ # select info to show
166
+ info = st.multiselect('Show Info',
167
+ ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
168
+ 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
169
+ 'clip+rank+pop', 'weighted_score_sum'],
170
+ default=sort_by)
171
+
172
+ # apply sorting to dataframe
173
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
174
+
175
+ # select number of columns
176
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
177
+
178
+ return items, info, col_num
179
+
180
+ def sidebar(self):
181
+ with st.sidebar:
182
+ prompt_tags = self.promptBook['tag'].unique()
183
+ # sort tags by alphabetical order
184
+ prompt_tags = np.sort(prompt_tags)[::-1]
185
+
186
+ tag = st.selectbox('Select a tag', prompt_tags)
187
+
188
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
189
+
190
+ original_prompts = np.sort(items['prompt'].unique())[::-1]
191
+
192
+ # remove the first four items in the prompt, which are mostly the same
193
+ if tag != 'abstract':
194
+ prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
195
+ prompt = st.selectbox('Select prompt', prompts)
196
+
197
+ idx = prompts.index(prompt)
198
+ prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
199
+ else:
200
+ prompt_full = st.selectbox('Select prompt', original_prompts)
201
+
202
+ items = items[items['prompt'] == prompt_full].reset_index(drop=True)
203
+ prompt_id = items['prompt_id'].unique()[0]
204
+
205
+ # show image metadata
206
+ image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
207
+ for key in image_metadatas:
208
+ label = ' '.join(key.split('_')).capitalize()
209
+ st.write(f"**{label}**")
210
+ if items[key][0] == ' ':
211
+ st.write('`None`')
212
+ else:
213
+ st.caption(f"{items[key][0]}")
214
+
215
+ # for tag as civitai, add civitai reference
216
+ if tag == 'civitai':
217
+ try:
218
+ st.write('**Civitai Reference**')
219
+ res = requests.get(f'https://civitai.com/images/{prompt_id.item()}')
220
+ # st.write(res.text)
221
+ soup = BeautifulSoup(res.text, 'html.parser')
222
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
223
+ image_url = image_section.find('img')['src']
224
+ st.image(image_url, use_column_width=True)
225
+ except:
226
+ pass
227
+
228
+ return prompt_tags, tag, prompt_id, items
229
+
230
+ def app(self):
231
+ st.title('Model Visualization and Retrieval')
232
+ st.write('This is a gallery of images generated by the models')
233
+
234
+ prompt_tags, tag, prompt_id, items = self.sidebar()
235
+
236
+ # add safety check for some prompts
237
+ safety_check = True
238
+ unsafe_prompts = {}
239
+ # initialize unsafe prompts
240
+ for prompt_tag in prompt_tags:
241
+ unsafe_prompts[prompt_tag] = []
242
+ # manually add unsafe prompts
243
+ unsafe_prompts['civitai'] = [375790, 366222, 295008, 256477]
244
+ unsafe_prompts['people'] = [53]
245
+ unsafe_prompts['art'] = [23]
246
+ unsafe_prompts['abstract'] = [10, 12]
247
+ unsafe_prompts['food'] = [34]
248
+
249
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
250
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
251
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
252
+
253
+ if safety_check:
254
+ items, info, col_num = self.selection_panel(items)
255
+ # self.gallery_standard(items, col_num, info)
256
+
257
+ with st.form(key=f'{prompt_id}'):
258
+ # buttons = st.columns([1, 1, 1])
259
+ buttons_space = st.columns([1, 1, 1, 1])
260
+ gallery_space = st.empty()
261
+
262
+ with buttons_space[0]:
263
+ continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
264
+ if continue_btn:
265
+ self.submit_actions('Continue', prompt_id)
266
+
267
+ with buttons_space[1]:
268
+ select_btn = st.form_submit_button('Select All', use_container_width=True)
269
+ if select_btn:
270
+ self.submit_actions('Select', prompt_id)
271
+
272
+ with buttons_space[2]:
273
+ deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
274
+ if deselect_btn:
275
+ self.submit_actions('Deselect', prompt_id)
276
+
277
+ with buttons_space[3]:
278
+ refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
279
+
280
+ with gallery_space.container():
281
+ with st.spinner('Loading images...'):
282
+ self.gallery_standard(items, col_num, info)
283
+
284
+ def submit_actions(self, status, prompt_id):
285
+ if status == 'Select':
286
+ modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
287
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
288
+ print(st.session_state.selected_dict, 'select')
289
+ elif status == 'Deselect':
290
+ st.session_state.selected_dict[prompt_id] = []
291
+ print(st.session_state.selected_dict, 'deselect')
292
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
293
+ pass
294
+ elif status == 'Continue':
295
+ # switch_page("ranking")
296
+ pass
297
+
298
+
299
+ @st.cache_data
300
+ def load_hf_dataset():
301
+ # login to huggingface
302
+ login(token=os.environ.get("HF_TOKEN"))
303
+
304
+ # load from huggingface
305
+ roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
306
+ promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
307
+ images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
308
+
309
+ # process dataset
310
+ roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
311
+ 'model_download_count']].drop_duplicates().reset_index(drop=True)
312
+
313
+ # # add 'checked' column to promptBook if not exist
314
+ # if 'checked' not in promptBook.columns:
315
+ # promptBook.loc[:, 'checked'] = False
316
+
317
+ # add 'custom_score_weights' column to promptBook if not exist
318
+ if 'weighted_score_sum' not in promptBook.columns:
319
+ promptBook.loc[:, 'weighted_score_sum'] = 0
320
+
321
+ # merge roster and promptbook
322
+ promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
323
+ on=['model_id', 'modelVersion_id'], how='left')
324
+
325
+ # add column to record current row index
326
+ promptBook.loc[:, 'row_idx'] = promptBook.index
327
+
328
+ return roster, promptBook, images_ds
329
+
330
+
331
+ if __name__ == "__main__":
332
+ st.set_page_config(page_title="Model Coffer Gallery", page_icon="๐Ÿ–ผ๏ธ", layout="wide")
333
+ if 'user_id' not in st.session_state:
334
+ st.warning('Please log in first.')
335
+ home_btn = st.button('Go to Home Page')
336
+ if home_btn:
337
+ switch_page("home")
338
+ else:
339
+ st.write('You have already logged in as ' + st.session_state.user_id[0])
340
+ roster, promptBook, st.session_state["images_ds"] = load_hf_dataset()
341
+ # print(promptBook.columns)
342
+
343
+ # initialize selected_dict
344
+ if 'selected_dict' not in st.session_state:
345
+ st.session_state['selected_dict'] = {}
346
+
347
+ app = GalleryApp(promptBook=promptBook, images_ds=st.session_state.images_ds)
348
+ app.app()
pages/2_๐ŸŽ–๏ธ_Ranking.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ from streamlit_extras.switch_page_button import switch_page
5
+
6
+ if __name__ == "__main__":
7
+ st.set_page_config(page_title="Personal Image Ranking", page_icon="๐ŸŽ–๏ธ๏ธ", layout="wide")
8
+
9
+ if 'user_id' not in st.session_state:
10
+ st.warning('Please log in first.')
11
+ home_btn = st.button('Go to Home Page')
12
+ if home_btn:
13
+ switch_page("home")
14
+
15
+ else:
16
+ all_checked = 0
17
+ for key, value in st.session_state.selected_dict.items():
18
+ for v in value:
19
+ all_checked += 1
20
+
21
+ if all_checked == 0:
22
+ st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
23
+ gallery_btn = st.button('Go to Gallery')
24
+ if gallery_btn:
25
+ switch_page('gallery')
26
+ else:
27
+ st.write('You have checked ' + str(all_checked) + ' images.')
28
+
๐Ÿ _Home.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import random
3
+ import time
4
+ from streamlit_extras.switch_page_button import switch_page
5
+
6
+
7
+ def login():
8
+ # skip customize user name for debug mode
9
+
10
+ with st.form("user_login"):
11
+ st.write('## Enter Your Name')
12
+ user_id = st.text_input(
13
+ "Enter your name for personalization ๐Ÿ‘‡",
14
+ label_visibility='visible',
15
+ disabled=False,
16
+ placeholder='anonymous',
17
+ )
18
+ st.write('You can leave it blank to be anonymous.')
19
+
20
+ # Every form must have a submit button.
21
+ submitted = st.form_submit_button("Start")
22
+ if submitted:
23
+ save_user_id(user_id)
24
+ switch_page("gallery")
25
+
26
+
27
+ def save_user_id(user_id):
28
+ print(user_id)
29
+ if not user_id:
30
+ user_id = 'anonymous' + str(random.randint(0, 100000))
31
+ st.session_state.user_id = [user_id, time.time()]
32
+
33
+
34
+ if __name__ == '__main__':
35
+ st.set_page_config(page_title="Login", page_icon="๐Ÿ ")
36
+
37
+ st.title("Personalized Image Ranking")
38
+ st.write(
39
+ "This is an web application to collect personal preference to ai generated images. \
40
+ You can know which model you like most after you finish the survey."
41
+ )
42
+
43
+ if 'user_id' not in st.session_state:
44
+ login()
45
+ else:
46
+ st.write('You have already logged in as ' + st.session_state.user_id[0])
47
+ st.button('Log out', on_click=lambda: st.session_state.pop('user_id'))
48
+