Ricercar committed on
Commit
2a117d2
1 Parent(s): 80c61aa

first version of ranking page!

Browse files
Files changed (3) hide show
  1. data/ranking_script.py +16 -0
  2. pages/Gallery.py +4 -2
  3. pages/Ranking.py +119 -40
data/ranking_script.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset
2
+
3
+
4
def init_ranking_data():
    """Create the ranking dataset with one example row and push it to the Hub.

    Builds an empty dataset with the columns ``image_id``,
    ``modelVersion_id``, ``ranking``, ``user_name`` and ``timestamp``,
    appends a single placeholder record, and uploads the result as the
    ``train`` split of ``MAPS-research/GEMRec-Ranking``.
    """
    columns = {'image_id': [], 'modelVersion_id': [], 'ranking': [], "user_name": [], "timestamp": []}
    ds = Dataset.from_dict(columns)

    # Add example data. Note that image_id is a string, other ids are int.
    # NOTE(review): presumably this row exists so the column types are taken
    # from real values rather than from empty lists -- confirm.
    example_row = {'image_id': '0', 'modelVersion_id': 0, 'ranking': 0, "user_name": "example_data", "timestamp": 0.0}
    ds = ds.add_item(example_row)

    ds.push_to_hub("MAPS-research/GEMRec-Ranking", split='train')


if __name__ == '__main__':
    init_ranking_data()
16
+
pages/Gallery.py CHANGED
@@ -278,6 +278,8 @@ class GalleryApp:
278
  switch_page("ranking")
279
 
280
  def submit_actions(self, status, prompt_id):
 
 
281
  if status == 'Select':
282
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
283
  st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
@@ -400,8 +402,8 @@ def load_hf_dataset():
400
  login(token=os.environ.get("HF_TOKEN"))
401
 
402
  # load from huggingface
403
- roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
404
- promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
405
  # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
406
  images_ds = None # set to None for now since we use s3 bucket to store images
407
 
 
278
  switch_page("ranking")
279
 
280
  def submit_actions(self, status, prompt_id):
281
+ # remove counter from session state
282
+ st.session_state.pop('counter', None)
283
  if status == 'Select':
284
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
285
  st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
 
402
  login(token=os.environ.get("HF_TOKEN"))
403
 
404
  # load from huggingface
405
+ roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
406
+ promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))
407
  # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
408
  images_ds = None # set to None for now since we use s3 bucket to store images
409
 
pages/Ranking.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  import pandas as pd
3
  import streamlit as st
@@ -17,7 +18,7 @@ class RankingApp:
17
  # self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
18
 
19
  if 'counter' not in st.session_state:
20
- st.session_state.counter = 0
21
 
22
  def sidebar(self):
23
  with st.sidebar:
@@ -37,18 +38,27 @@ class RankingApp:
37
  # input image metadata
38
  prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
39
  negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
40
- st.form_submit_button('Generate Images', type='primary', use_container_width=True)
41
 
42
  return prompt_tags, tag, prompt_id, items
43
 
44
- def draggable_images(self, items, layout='portrait'):
45
  # init ranking by the order of items
 
46
  if 'ranking' not in st.session_state:
47
  st.session_state.ranking = {}
 
 
 
 
 
 
48
  for i in range(len(items)):
49
- st.session_state.ranking[str(items['image_id'][i])] = i
 
 
 
50
 
51
- print(items)
52
  with elements('dashboard'):
53
  if layout == 'portrait':
54
  col_num = 4
@@ -57,14 +67,17 @@ class RankingApp:
57
  elif layout == 'landscape':
58
  col_num = 2
59
  layout = [
60
- dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.4, isResizable=False) for
61
  i in range(len(items))
62
  ]
63
 
64
  with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
65
  for i in range(len(layout)):
66
  with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
67
- rank = st.session_state.ranking[str(items['image_id'][i])] + 1
 
 
 
68
 
69
  mui.Chip(label=rank,
70
  # variant="outlined" if rank!=1 else "default",
@@ -79,7 +92,7 @@ class RankingApp:
79
  # image={"data:image/png;base64", img_str},
80
  image=img_url,
81
  alt="There should be an image",
82
- sx={"height": "100%", "object-fit": "fit", 'bgcolor': 'black'},
83
  )
84
 
85
  def handle_layout_change(self, updated_layout):
@@ -87,26 +100,95 @@ class RankingApp:
87
  sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
88
  sorted_list = [str(item['i']) for item in sorted_list]
89
 
90
- for k in st.session_state.ranking.keys():
91
- st.session_state.ranking[k] = sorted_list.index(k)
 
 
 
92
 
93
  def app(self):
94
  st.title('Personal Image Ranking')
95
  st.write('Here you can test out your selected images with any prompt you like.')
96
  # st.write(self.promptBook)
97
 
 
 
 
 
 
98
  prompt_tags, tag, prompt_id, items = self.sidebar()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- sorting, control = st.columns((11, 1), gap='large')
101
- with sorting:
102
- # st.write('## Sorting')
103
- # st.write('Please drag the images to sort them.')
104
- st.progress((st.session_state.counter + 1) / self.batch_num, text=f"Batch {st.session_state.counter + 1} / {self.batch_num}")
105
- self.draggable_images(items.iloc[self.batch_size*st.session_state.counter: self.batch_size*(st.session_state.counter+1)], layout='portrait')
106
 
107
- with control:
108
- st.button(":arrow_right:")
109
- st.button(":slightly_frowning_face:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  if __name__ == "__main__":
@@ -119,13 +201,15 @@ if __name__ == "__main__":
119
  switch_page("home")
120
 
121
  else:
122
- selected_modelVersions = []
 
123
  for key, value in st.session_state.selected_dict.items():
124
  for v in value:
125
- if v not in selected_modelVersions:
126
- selected_modelVersions.append(v)
 
127
 
128
- if len(selected_modelVersions) == 0:
129
  st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
130
  gallery_btn = st.button('Go to Gallery')
131
  if gallery_btn:
@@ -134,21 +218,16 @@ if __name__ == "__main__":
134
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
135
  roster, promptBook, images_ds = load_hf_dataset()
136
  print(st.session_state.selected_dict)
137
- st.write("# Full function is coming soon.")
138
- st.write("## roster")
139
- st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])
140
-
141
- # st.write(roster)
142
- # st.write("## promptBook")
143
- # st.write(promptBook)
144
-
145
- # # only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
146
- # promptBook_selected = pd.DataFrame()
147
- # for key, value in st.session_state.selected_dict.items():
148
- # promptBook_selected = promptBook_selected.append(promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))])
149
- # promptBook_selected = promptBook_selected.reset_index(drop=True)
150
- # images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
151
- #
152
- # app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
153
- # app.app()
154
 
 
1
+ import datasets
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
18
  # self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
19
 
20
  if 'counter' not in st.session_state:
21
+ st.session_state.counter = {}
22
 
23
  def sidebar(self):
24
  with st.sidebar:
 
38
  # input image metadata
39
  prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
40
  negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
41
+ st.form_submit_button('Generate Images [Coming Soon]', type='primary', use_container_width=True, disabled=True)
42
 
43
  return prompt_tags, tag, prompt_id, items
44
 
45
+ def draggable_images(self, items, prompt_id, layout='portrait'):
46
  # init ranking by the order of items
47
+
48
  if 'ranking' not in st.session_state:
49
  st.session_state.ranking = {}
50
+
51
+ if prompt_id not in st.session_state.ranking:
52
+ st.session_state.ranking[prompt_id] = {}
53
+
54
+ if st.session_state.counter[prompt_id] not in st.session_state.ranking[prompt_id]:
55
+ st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]] = {}
56
  for i in range(len(items)):
57
+ st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(items['image_id'][i])] = i
58
+ else:
59
+ # set the index of items to the corresponding ranking value of the image_id
60
+ items.index = items['image_id'].apply(lambda x: st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(x)])
61
 
 
62
  with elements('dashboard'):
63
  if layout == 'portrait':
64
  col_num = 4
 
67
  elif layout == 'landscape':
68
  col_num = 2
69
  layout = [
70
+ dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.6, isResizable=False) for
71
  i in range(len(items))
72
  ]
73
 
74
  with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
75
  for i in range(len(layout)):
76
  with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
77
+ prompt_id = st.session_state.prompt_id_tmp
78
+ batch_idx = st.session_state.counter[prompt_id]
79
+
80
+ rank = st.session_state.ranking[prompt_id][batch_idx][str(items['image_id'][i])] + 1
81
 
82
  mui.Chip(label=rank,
83
  # variant="outlined" if rank!=1 else "default",
 
92
  # image={"data:image/png;base64", img_str},
93
  image=img_url,
94
  alt="There should be an image",
95
+ sx={"height": "100%", "object-fit": "contain", 'bgcolor': 'black'},
96
  )
97
 
98
  def handle_layout_change(self, updated_layout):
 
100
  sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
101
  sorted_list = [str(item['i']) for item in sorted_list]
102
 
103
+ prompt_id = st.session_state.prompt_id_tmp
104
+ batch_idx = st.session_state.counter[prompt_id]
105
+
106
+ for k in st.session_state.ranking[prompt_id][batch_idx].keys():
107
+ st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
108
 
109
  def app(self):
110
  st.title('Personal Image Ranking')
111
  st.write('Here you can test out your selected images with any prompt you like.')
112
  # st.write(self.promptBook)
113
 
114
+ # save the current progress to session state
115
+ if 'progress' not in st.session_state:
116
+ st.session_state.progress = {}
117
+ # print('current progress: ', st.session_state.progress)
118
+
119
  prompt_tags, tag, prompt_id, items = self.sidebar()
120
+ batch_num = len(items) // self.batch_size
121
+ batch_num += 1 if len(items) % self.batch_size != 0 else 0
122
+
123
+ st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else st.session_state.counter[prompt_id]
124
+
125
+ # save prompt_id in session state
126
+ st.session_state.prompt_id_tmp = prompt_id
127
+
128
+ if prompt_id not in st.session_state.progress:
129
+ st.session_state.progress[prompt_id] = 'ranking'
130
+
131
+ if st.session_state.progress[prompt_id] == 'ranking':
132
+ sorting, control = st.columns((11, 1), gap='large')
133
+ with sorting:
134
+ # st.write('## Sorting')
135
+ # st.write('Please drag the images to sort them.')
136
+ st.progress((st.session_state.counter[prompt_id] + 1) / batch_num, text=f"Batch {st.session_state.counter[prompt_id] + 1} / {batch_num}")
137
+ # st.write(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)])
138
+
139
+ width, height = items.loc[0, 'size'].split('x')
140
+ if int(height) >= int(width):
141
+ self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='portrait')
142
+ else:
143
+ self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='landscape')
144
+ # st.write(str(st.session_state.ranking))
145
+ with control:
146
+ if st.session_state.counter[prompt_id] < batch_num - 1:
147
+ st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch', kwargs={'prompt_id': prompt_id})
148
+ else:
149
+ st.button(":ballot_box_with_check:", key='finished', on_click=self.next_batch, help='Finished', kwargs={'prompt_id': prompt_id, 'progress': 'finished'})
150
+
151
+ if st.session_state.counter[prompt_id] > 0:
152
+ st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch', kwargs={'prompt_id': prompt_id})
153
+
154
+ elif st.session_state.progress[prompt_id] == 'finished':
155
+ st.write('## You have ranked all models for this tag!')
156
+ st.write('Thank you for your participation! Feel free to do the following things:')
157
+ st.write('* Rank for other tags and prompts.')
158
+ st.write('* Back to the gallery page to see more images.')
159
+ st.write('* Rank again for this tag and prompt.')
160
+ st.write('*More functions are coming soon... Please stay tuned*')
161
+
162
+ gallery_btn = st.button('🖼️ Back to Gallery')
163
+ if gallery_btn:
164
+ switch_page('gallery')
165
+
166
+ restart_btn = st.button('🎖️ Rank Again')
167
+ if restart_btn:
168
+ st.session_state.progress['prompt_id'] = 'ranking'
169
+ st.session_state.counter[prompt_id] = 0
170
+ st.experimental_rerun()
171
 
 
 
 
 
 
 
172
 
173
+ def next_batch(self, prompt_id, progress=None):
174
+
175
+ # save ranking to dataset
176
+ # print(st.session_state.ranking)
177
+ ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
178
+ for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
179
+ modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
180
+ ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
181
+ # print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
182
+ ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
183
+ ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
184
+
185
+ if progress == 'finished':
186
+ st.session_state.progress['prompt_id'] = 'finished'
187
+ else:
188
+ st.session_state.counter[prompt_id] += 1
189
+
190
+ def prev_batch(self, prompt_id):
191
+ st.session_state.counter[prompt_id] -= 1
192
 
193
 
194
  if __name__ == "__main__":
 
201
  switch_page("home")
202
 
203
  else:
204
+ has_selection = False
205
+
206
  for key, value in st.session_state.selected_dict.items():
207
  for v in value:
208
+ if v:
209
+ has_selection = True
210
+ break
211
 
212
+ if not has_selection:
213
  st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
214
  gallery_btn = st.button('Go to Gallery')
215
  if gallery_btn:
 
218
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
219
  roster, promptBook, images_ds = load_hf_dataset()
220
  print(st.session_state.selected_dict)
221
+ # st.write("# Full function is coming soon.")
222
+
223
+ # only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
224
+ promptBook_selected = pd.DataFrame()
225
+ for key, value in st.session_state.selected_dict.items():
226
+ promptBook_selected = promptBook_selected.append(promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))])
227
+ promptBook_selected = promptBook_selected.reset_index(drop=True)
228
+ # st.write(promptBook_selected)
229
+ images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
230
+
231
+ app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
232
+ app.app()
 
 
 
 
 
233