Ricercar commited on
Commit
0cf0993
1 Parent(s): bac893c

gallery 2.0

Browse files

graph view added

Files changed (1) hide show
  1. pages/Gallery.py +165 -43
pages/Gallery.py CHANGED
@@ -9,6 +9,7 @@ import streamlit as st
9
  from bs4 import BeautifulSoup
10
  from datasets import load_dataset, Dataset, load_from_disk
11
  from huggingface_hub import login
 
12
  from streamlit_extras.switch_page_button import switch_page
13
  from sklearn.svm import LinearSVC
14
 
@@ -50,6 +51,55 @@ class GalleryApp:
50
  for key in info:
51
  st.write(f"**{key}**: {items.iloc[idx + j][key]}")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def selection_panel(self, items):
54
  # temperal function
55
 
@@ -170,6 +220,8 @@ class GalleryApp:
170
 
171
  selected_prompt = st.selectbox('Select prompt', prompts)
172
 
 
 
173
  items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
174
  prompt_id = items['prompt_id'].unique()[0]
175
  note = items['note'].unique()[0]
@@ -206,14 +258,14 @@ class GalleryApp:
206
  except:
207
  pass
208
 
209
- return prompt_tags, tag, prompt_id, items
210
 
211
  def app(self):
212
  st.title('Model Visualization and Retrieval')
213
  st.write('This is a gallery of images generated by the models')
214
 
215
- prompt_tags, tag, prompt_id, items = self.sidebar()
216
- items, info, col_num = self.selection_panel(items)
217
 
218
  # add safety check for some prompts
219
  safety_check = True
@@ -230,57 +282,115 @@ class GalleryApp:
230
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
231
 
232
  if safety_check:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- if 'selected_dict' in st.session_state:
235
- # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
236
- dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
237
- dynamic_weight_panel = st.columns(len(dynamic_weight_options))
238
-
239
- if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
240
- btn_disable = False
241
  else:
242
- btn_disable = True
 
 
 
 
 
 
 
 
 
 
243
 
244
- for i in range(len(dynamic_weight_options)):
245
- method = dynamic_weight_options[i]
246
- with dynamic_weight_panel[i]:
247
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
250
- if prompt:
251
- switch_page("ranking")
 
252
 
253
- with st.form(key=f'{prompt_id}'):
254
- # buttons = st.columns([1, 1, 1])
255
- buttons_space = st.columns([1, 1, 1, 1])
256
- gallery_space = st.empty()
257
 
258
- with buttons_space[0]:
259
- continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
260
- if continue_btn:
261
- self.submit_actions('Continue', prompt_id)
262
 
263
- with buttons_space[1]:
264
- select_btn = st.form_submit_button('Select All', use_container_width=True)
265
- if select_btn:
266
- self.submit_actions('Select', prompt_id)
267
 
268
- with buttons_space[2]:
269
- deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
270
- if deselect_btn:
271
- self.submit_actions('Deselect', prompt_id)
272
 
273
- with buttons_space[3]:
274
- refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- with gallery_space.container():
277
- with st.spinner('Loading images...'):
278
- self.gallery_standard(items, col_num, info)
279
 
280
- st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
281
- # prompt = st.chat_input(f"checked: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
282
- # if prompt:
283
- # switch_page("ranking")
284
 
285
  def submit_actions(self, status, prompt_id):
286
  # remove counter from session state
@@ -429,6 +539,18 @@ def load_hf_dataset():
429
 
430
  return roster, promptBook, images_ds
431
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  if __name__ == "__main__":
434
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
 
9
  from bs4 import BeautifulSoup
10
  from datasets import load_dataset, Dataset, load_from_disk
11
  from huggingface_hub import login
12
+ from streamlit_agraph import agraph, Node, Edge, Config
13
  from streamlit_extras.switch_page_button import switch_page
14
  from sklearn.svm import LinearSVC
15
 
 
51
  for key in info:
52
  st.write(f"**{key}**: {items.iloc[idx + j][key]}")
53
 
54
+ def gallery_graph(self, items):
55
+ items = load_tsne_coordinates(items)
56
+
57
+ scale = 50
58
+ items.loc[:, 'x'] = items['x'] * scale
59
+ items.loc[:, 'y'] = items['y'] * scale
60
+
61
+ nodes = []
62
+ edges = []
63
+
64
+ for idx in items.index:
65
+ # if items.loc[idx, 'modelVersion_id'] in st.session_state.selected_dict.get(items.loc[idx, 'prompt_id'], 0):
66
+ # opacity = 0.2
67
+ # else:
68
+ # opacity = 1.0
69
+
70
+ nodes.append(Node(id=items.loc[idx, 'image_id'],
71
+ # label=str(items.loc[idx, 'model_name']),
72
+ size=20,
73
+ shape='image',
74
+ image=f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.loc[idx, 'image_id']}.png",
75
+ x=items.loc[idx, 'x'].item(),
76
+ y=items.loc[idx, 'y'].item(),
77
+ fixed=True,
78
+ color={'background': '#00000', 'border': '#ffffff'},
79
+ # opacity=opacity,
80
+ shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
81
+ # borderWidth=1,
82
+ # shapeProperties={'useBorderWithImage': True},
83
+ )
84
+ )
85
+
86
+ config = Config(width='100%',
87
+ height=800,
88
+ directed=True,
89
+ physics=False,
90
+ hierarchical=False,
91
+ # **kwargs
92
+ )
93
+
94
+ return agraph(nodes=nodes,
95
+ edges=edges,
96
+ config=config
97
+ )
98
+
99
+
100
+
101
+
102
+
103
  def selection_panel(self, items):
104
  # temperal function
105
 
 
220
 
221
  selected_prompt = st.selectbox('Select prompt', prompts)
222
 
223
+ mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)
224
+
225
  items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
226
  prompt_id = items['prompt_id'].unique()[0]
227
  note = items['note'].unique()[0]
 
258
  except:
259
  pass
260
 
261
+ return prompt_tags, tag, prompt_id, items, mode
262
 
263
  def app(self):
264
  st.title('Model Visualization and Retrieval')
265
  st.write('This is a gallery of images generated by the models')
266
 
267
+ prompt_tags, tag, prompt_id, items, mode = self.sidebar()
268
+ # items, info, col_num = self.selection_panel(items)
269
 
270
  # add safety check for some prompts
271
  safety_check = True
 
282
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
283
 
284
  if safety_check:
285
+ if mode == 'Gallery':
286
+ self.gallery_mode(prompt_id, items)
287
+ elif mode == 'Graph':
288
+ self.graph_mode(prompt_id, items)
289
+
290
+
291
+ def graph_mode(self, prompt_id, items):
292
+ graph_cols = st.columns([3, 1])
293
+ prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
294
+ disabled=False, key=f'{prompt_id}')
295
+ if prompt:
296
+ switch_page("ranking")
297
+
298
+ with graph_cols[0]:
299
+ return_value = self.gallery_graph(items)
300
+ with graph_cols[1]:
301
+ if return_value:
302
+ image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{return_value}.png"
303
+
304
+ st.image(image_url)
305
+
306
+ item = items[items['image_id'] == return_value].reset_index(drop=True).iloc[0]
307
+ modelVersion_id = item['modelVersion_id']
308
+
309
+ # handle selection
310
+ if 'selected_dict' in st.session_state:
311
+ if item['prompt_id'] not in st.session_state.selected_dict:
312
+ st.session_state.selected_dict[item['prompt_id']] = []
313
+
314
+ if modelVersion_id in st.session_state.selected_dict[item['prompt_id']]:
315
+ checked = True
316
+ else:
317
+ checked = False
318
+
319
+ if checked:
320
+ deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
321
+ if deselect:
322
+ st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
323
+ st.experimental_rerun()
324
 
 
 
 
 
 
 
 
325
  else:
326
+ select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
327
+ if select:
328
+ st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
329
+ st.experimental_rerun()
330
+
331
+ # st.write(item)
332
+ infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
333
+ 'nsfw_score']
334
+ for info in infos:
335
+ st.write(f"**{info}**:")
336
+ st.write(item[info])
337
 
338
+ else:
339
+ st.info('Please click on an image to show')
340
+
341
+
342
+ def gallery_mode(self, prompt_id, items):
343
+ items, info, col_num = self.selection_panel(items)
344
+
345
+ if 'selected_dict' in st.session_state:
346
+ # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
347
+ dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
348
+ dynamic_weight_panel = st.columns(len(dynamic_weight_options))
349
+
350
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
351
+ btn_disable = False
352
+ else:
353
+ btn_disable = True
354
 
355
+ for i in range(len(dynamic_weight_options)):
356
+ method = dynamic_weight_options[i]
357
+ with dynamic_weight_panel[i]:
358
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
359
 
360
+ prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
361
+ if prompt:
362
+ switch_page("ranking")
 
363
 
364
+ with st.form(key=f'{prompt_id}'):
365
+ # buttons = st.columns([1, 1, 1])
366
+ buttons_space = st.columns([1, 1, 1, 1])
367
+ gallery_space = st.empty()
368
 
369
+ with buttons_space[0]:
370
+ continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
371
+ if continue_btn:
372
+ self.submit_actions('Continue', prompt_id)
373
 
374
+ with buttons_space[1]:
375
+ select_btn = st.form_submit_button('Select All', use_container_width=True)
376
+ if select_btn:
377
+ self.submit_actions('Select', prompt_id)
378
 
379
+ with buttons_space[2]:
380
+ deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
381
+ if deselect_btn:
382
+ self.submit_actions('Deselect', prompt_id)
383
+
384
+ with buttons_space[3]:
385
+ refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
386
+
387
+ with gallery_space.container():
388
+ with st.spinner('Loading images...'):
389
+ self.gallery_standard(items, col_num, info)
390
+
391
+ st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
392
 
 
 
 
393
 
 
 
 
 
394
 
395
  def submit_actions(self, status, prompt_id):
396
  # remove counter from session state
 
539
 
540
  return roster, promptBook, images_ds
541
 
542
+ @st.cache_data
543
+ def load_tsne_coordinates(items):
544
+ # load tsne coordinates
545
+ tsne_df = pd.read_parquet('./data/feats_tsne.parquet')
546
+
547
+ # print(tsne_df['modelVersion_id'].dtype)
548
+
549
+ print('before merge:', items)
550
+ items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
551
+ print('after merge:', items)
552
+ return items
553
+
554
 
555
  if __name__ == "__main__":
556
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")