Ricercar commited on
Commit
00a6576
1 Parent(s): 40c1b17

ranking is not finished!!!

Browse files
pages/Gallery.py CHANGED
@@ -45,19 +45,7 @@ class GalleryApp:
45
  st.write("Position: ", idx + j)
46
 
47
  # show checkbox
48
- checked = st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
49
-
50
- #
51
- # if checked:
52
- # if prompt_id not in st.session_state.selected_dict:
53
- # st.session_state.selected_dict[prompt_id] = []
54
- # if modelVersion_id not in st.session_state.selected_dict[prompt_id]:
55
- # st.session_state.selected_dict[prompt_id].append(modelVersion_id)
56
- # else:
57
- # try:
58
- # st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
59
- # except:
60
- # pass
61
 
62
  # show selected info
63
  for key in info:
@@ -65,7 +53,6 @@ class GalleryApp:
65
 
66
  def selection_panel(self, items):
67
  # temperal function
68
- preprocessor = st.radio('Preprocess Method', ['crop', 'resize'], horizontal=True)
69
 
70
  selecters = st.columns([1, 4])
71
 
@@ -101,7 +88,7 @@ class GalleryApp:
101
  with sub_selecters[2]:
102
  pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
103
 
104
- items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip_{preprocessor}'] * clip_weight + items[f'norm_mcos_{preprocessor}'] * mcos_weight + items[
105
  'norm_pop'] * pop_weight, 4)
106
 
107
  continue_idx = 3
@@ -168,7 +155,7 @@ class GalleryApp:
168
  # select number of columns
169
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
170
 
171
- return items, info, col_num, preprocessor
172
 
173
  def sidebar(self):
174
  with st.sidebar:
@@ -226,7 +213,7 @@ class GalleryApp:
226
  st.title('Model Visualization and Retrieval')
227
  st.write('This is a gallery of images generated by the models')
228
 
229
- prompt_tags, tag, prompt_id, items= self.sidebar()
230
 
231
  # add safety check for some prompts
232
  safety_check = True
@@ -245,7 +232,7 @@ class GalleryApp:
245
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
246
 
247
  if safety_check:
248
- items, info, col_num, preprocessor = self.selection_panel(items)
249
 
250
  if 'selected_dict' in st.session_state:
251
  st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
@@ -260,7 +247,7 @@ class GalleryApp:
260
  for i in range(len(dynamic_weight_options)):
261
  method = dynamic_weight_options[i]
262
  with dynamic_weight_panel[i]:
263
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, preprocessor, method))
264
 
265
  with st.form(key=f'{prompt_id}'):
266
  # buttons = st.columns([1, 1, 1])
@@ -311,7 +298,7 @@ class GalleryApp:
311
  print(st.session_state.selected_dict, 'continue')
312
  st.experimental_rerun()
313
 
314
- def dynamic_weight(self, prompt_id, items, preprocessor='crop', method='Grid Search'):
315
  selected = items[
316
  items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
317
  optimal_weight = [0, 0, 0]
@@ -324,10 +311,10 @@ class GalleryApp:
324
  for mcos_weight in np.arange(-1, 1, 0.1):
325
  for pop_weight in np.arange(-1, 1, 0.1):
326
 
327
- weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
328
  weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
329
  # print('weight_all_sorted:', weight_all_sorted)
330
- weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
331
 
332
  # get the index of values of weight_selected in weight_all_sorted
333
  rankings = []
@@ -438,6 +425,9 @@ def load_hf_dataset():
438
  if __name__ == "__main__":
439
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
440
 
 
 
 
441
  if 'user_id' not in st.session_state:
442
  st.warning('Please log in first.')
443
  home_btn = st.button('Go to Home Page')
 
45
  st.write("Position: ", idx + j)
46
 
47
  # show checkbox
48
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # show selected info
51
  for key in info:
 
53
 
54
  def selection_panel(self, items):
55
  # temperal function
 
56
 
57
  selecters = st.columns([1, 4])
58
 
 
88
  with sub_selecters[2]:
89
  pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
 
91
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
  'norm_pop'] * pop_weight, 4)
93
 
94
  continue_idx = 3
 
155
  # select number of columns
156
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
157
 
158
+ return items, info, col_num
159
 
160
  def sidebar(self):
161
  with st.sidebar:
 
213
  st.title('Model Visualization and Retrieval')
214
  st.write('This is a gallery of images generated by the models')
215
 
216
+ prompt_tags, tag, prompt_id, items = self.sidebar()
217
 
218
  # add safety check for some prompts
219
  safety_check = True
 
232
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
 
234
  if safety_check:
235
+ items, info, col_num = self.selection_panel(items)
236
 
237
  if 'selected_dict' in st.session_state:
238
  st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
 
247
  for i in range(len(dynamic_weight_options)):
248
  method = dynamic_weight_options[i]
249
  with dynamic_weight_panel[i]:
250
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
251
 
252
  with st.form(key=f'{prompt_id}'):
253
  # buttons = st.columns([1, 1, 1])
 
298
  print(st.session_state.selected_dict, 'continue')
299
  st.experimental_rerun()
300
 
301
+ def dynamic_weight(self, prompt_id, items, method='Grid Search'):
302
  selected = items[
303
  items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
304
  optimal_weight = [0, 0, 0]
 
311
  for mcos_weight in np.arange(-1, 1, 0.1):
312
  for pop_weight in np.arange(-1, 1, 0.1):
313
 
314
+ weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
315
  weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
316
  # print('weight_all_sorted:', weight_all_sorted)
317
+ weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
318
 
319
  # get the index of values of weight_selected in weight_all_sorted
320
  rankings = []
 
425
  if __name__ == "__main__":
426
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
427
 
428
+ # remove ranking in the session state if it is created in Ranking.py
429
+ st.session_state.pop('ranking', None)
430
+
431
  if 'user_id' not in st.session_state:
432
  st.warning('Please log in first.')
433
  home_btn = st.button('Go to Home Page')
pages/Ranking.py CHANGED
@@ -8,24 +8,105 @@ from streamlit_extras.switch_page_button import switch_page
8
  from pages.Gallery import load_hf_dataset
9
 
10
 
11
- class RankingApp():
12
- def __init__(self, promptBook, images_ds):
13
  self.promptBook = promptBook
14
- self.images_ds = images_ds
 
 
 
15
 
16
- def draggable_images(self, items, layout='vertical'):
17
- pass
18
 
19
  def sidebar(self):
20
  with st.sidebar:
21
  prompt_tags = self.promptBook['tag'].unique()
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def app(self):
25
  st.title('Personal Image Ranking')
26
  st.write('Here you can test out your selected images with any prompt you like.')
 
27
 
28
- prompt_tags, tag, prompt_id, items= self.sidebar()
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  if __name__ == "__main__":
@@ -50,11 +131,22 @@ if __name__ == "__main__":
50
  if gallery_btn:
51
  switch_page('gallery')
52
  else:
53
- st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
54
  roster, promptBook, images_ds = load_hf_dataset()
55
- st.write("## roster")
56
- st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])
 
57
  # st.write(roster)
58
  # st.write("## promptBook")
59
  # st.write(promptBook)
60
 
 
 
 
 
 
 
 
 
 
 
 
8
  from pages.Gallery import load_hf_dataset
9
 
10
 
11
+ class RankingApp:
12
+ def __init__(self, promptBook, images_endpoint, batch_size=4):
13
  self.promptBook = promptBook
14
+ self.images_endpoint = images_endpoint
15
+ self.batch_size = batch_size
16
+ # self.batch_num = len(self.promptBook) // self.batch_size
17
+ # self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
18
 
19
+ if 'counter' not in st.session_state:
20
+ st.session_state.counter = 0
21
 
22
  def sidebar(self):
23
  with st.sidebar:
24
  prompt_tags = self.promptBook['tag'].unique()
25
+ prompt_tags = np.sort(prompt_tags)
26
 
27
+ tag = st.selectbox('Select a prompt tag', prompt_tags)
28
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
29
+ prompts = np.sort(items['prompt'].unique())[::-1]
30
+
31
+ selected_prompt = st.selectbox('Select a prompt', prompts)
32
+
33
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
34
+ prompt_id = items['prompt_id'].unique()[0]
35
+
36
+ with st.form(key='prompt_form'):
37
+ # input image metadata
38
+ prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
39
+ negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
40
+ st.form_submit_button('Generate Images', type='primary', use_container_width=True)
41
+
42
+ return prompt_tags, tag, prompt_id, items
43
+
44
+ def draggable_images(self, items, layout='portrait'):
45
+ # init ranking by the order of items
46
+ if 'ranking' not in st.session_state:
47
+ st.session_state.ranking = {}
48
+ for i in range(len(items)):
49
+ st.session_state.ranking[str(items['image_id'][i])] = i
50
+
51
+ print(items)
52
+ with elements('dashboard'):
53
+ if layout == 'portrait':
54
+ col_num = 4
55
+ layout = [dashboard.Item(str(items['image_id'][i]), i % col_num, i//col_num, 1, 2, isResizable=False) for i in range(len(items))]
56
+
57
+ elif layout == 'landscape':
58
+ col_num = 2
59
+ layout = [
60
+ dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.4, isResizable=False) for
61
+ i in range(len(items))
62
+ ]
63
+
64
+ with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
65
+ for i in range(len(layout)):
66
+ with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
67
+ rank = st.session_state.ranking[str(items['image_id'][i])] + 1
68
+
69
+ mui.Chip(label=rank,
70
+ # variant="outlined" if rank!=1 else "default",
71
+ color="primary" if rank == 1 else "warning" if rank == 2 else "info",
72
+ size="small",
73
+ sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})
74
+
75
+ img_url = self.images_endpoint + str(items['image_id'][i]) + '.png'
76
+
77
+ mui.CardMedia(
78
+ component="img",
79
+ # image={"data:image/png;base64", img_str},
80
+ image=img_url,
81
+ alt="There should be an image",
82
+ sx={"height": "100%", "object-fit": "fit", 'bgcolor': 'black'},
83
+ )
84
+
85
+ def handle_layout_change(self, updated_layout):
86
+ # print(updated_layout)
87
+ sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
88
+ sorted_list = [str(item['i']) for item in sorted_list]
89
+
90
+ for k in st.session_state.ranking.keys():
91
+ st.session_state.ranking[k] = sorted_list.index(k)
92
 
93
  def app(self):
94
  st.title('Personal Image Ranking')
95
  st.write('Here you can test out your selected images with any prompt you like.')
96
+ # st.write(self.promptBook)
97
 
98
+ prompt_tags, tag, prompt_id, items = self.sidebar()
99
+
100
+ sorting, control = st.columns((11, 1), gap='large')
101
+ with sorting:
102
+ # st.write('## Sorting')
103
+ # st.write('Please drag the images to sort them.')
104
+ st.progress((st.session_state.counter + 1) / self.batch_num, text=f"Batch {st.session_state.counter + 1} / {self.batch_num}")
105
+ self.draggable_images(items.iloc[self.batch_size*st.session_state.counter: self.batch_size*(st.session_state.counter+1)], layout='portrait')
106
+
107
+ with control:
108
+ st.button(":arrow_right:")
109
+ st.button(":slightly_frowning_face:")
110
 
111
 
112
  if __name__ == "__main__":
 
131
  if gallery_btn:
132
  switch_page('gallery')
133
  else:
134
+ # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
135
  roster, promptBook, images_ds = load_hf_dataset()
136
+ print(st.session_state.selected_dict)
137
+ # st.write("## roster")
138
+ # st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])
139
  # st.write(roster)
140
  # st.write("## promptBook")
141
  # st.write(promptBook)
142
 
143
+ # only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
144
+ promptBook_selected = pd.DataFrame()
145
+ for key, value in st.session_state.selected_dict.items():
146
+ promptBook_selected = promptBook_selected.append(promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))])
147
+ promptBook_selected = promptBook_selected.reset_index(drop=True)
148
+ images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
149
+
150
+ app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
151
+ app.app()
152
+
pages/__pycache__/Gallery.cpython-39.pyc CHANGED
Binary files a/pages/__pycache__/Gallery.cpython-39.pyc and b/pages/__pycache__/Gallery.cpython-39.pyc differ