Ricercar commited on
Commit
257e746
1 Parent(s): dcf0c90

add set top methods

Browse files
Files changed (1) hide show
  1. pages/Gallery.py +24 -24
pages/Gallery.py CHANGED
@@ -272,21 +272,20 @@ class GalleryApp:
272
  if safety_check:
273
  items, info, col_num, preprocessor = self.selection_panel(items)
274
 
275
- # method = st.radio('Select a method to set dynamic weight', ['Grid Search', 'SVM', 'Greedy', 'Disable dynamic weight'], index=0, horizontal=True)
276
- #
277
- # if method != 'Disable dynamic weight':
278
- # if len(st.session_state.selected_dict[prompt_id]) > 0:
279
- # selected = items[
280
- # items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
281
- # drop=True)
282
- # st.session_state.score_weights[0: 3] = self.dynamic_weight(selected, items, preprocessor,
283
- # method=method)
284
- # # st.experimental_rerun()
285
- #
286
- # else:
287
- # print('no selected models')
288
- #
289
- # st.write(st.session_state.selected_dict.get(prompt_id, []))
290
 
291
  with st.form(key=f'{prompt_id}'):
292
  # buttons = st.columns([1, 1, 1])
@@ -315,9 +314,6 @@ class GalleryApp:
315
  with st.spinner('Loading images...'):
316
  self.gallery_standard(items, col_num, info)
317
 
318
- with st.sidebar:
319
- st.write(str(st.session_state.selected_dict[prompt_id]))
320
-
321
  def submit_actions(self, status, prompt_id):
322
  if status == 'Select':
323
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
@@ -329,7 +325,6 @@ class GalleryApp:
329
  print(st.session_state.selected_dict, 'deselect')
330
  st.experimental_rerun()
331
  # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
332
- pass
333
  elif status == 'Continue':
334
  st.session_state.selected_dict[prompt_id] = []
335
  for key in st.session_state:
@@ -339,10 +334,13 @@ class GalleryApp:
339
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
340
  # switch_page("ranking")
341
  print(st.session_state.selected_dict, 'continue')
342
- pass
343
 
344
- def dynamic_weight(self, selected, items, preprocessor='crop', method='Grid Search'):
 
 
345
  optimal_weight = [0, 0, 0]
 
346
  if method == 'Grid Search':
347
  # grid search method
348
  top_ranking = len(items) * len(selected)
@@ -350,9 +348,10 @@ class GalleryApp:
350
  for clip_weight in np.arange(-1, 1, 0.1):
351
  for mcos_weight in np.arange(-1, 1, 0.1):
352
  for pop_weight in np.arange(-1, 1, 0.1):
353
- weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
354
- weight_all_sorted = weight_all.sort_values(ascending=False)
355
 
 
 
 
356
  weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
357
 
358
  # get the index of values of weight_selected in weight_all_sorted
@@ -361,6 +360,7 @@ class GalleryApp:
361
  rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
362
  if sum(rankings) <= top_ranking:
363
  top_ranking = sum(rankings)
 
364
  optimal_weight = [clip_weight, mcos_weight, pop_weight]
365
  print('optimal weight:', optimal_weight)
366
 
@@ -401,7 +401,7 @@ class GalleryApp:
401
  optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
402
  print('optimal weight:', optimal_weight)
403
 
404
- return optimal_weight
405
 
406
 
407
 
 
272
  if safety_check:
273
  items, info, col_num, preprocessor = self.selection_panel(items)
274
 
275
+ if 'selected_dict' in st.session_state:
276
+ st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
277
+ dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
278
+ dynamic_weight_panel = st.columns(len(dynamic_weight_options))
279
+
280
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
281
+ btn_disable = False
282
+ else:
283
+ btn_disable = True
284
+
285
+ for i in range(len(dynamic_weight_options)):
286
+ method = dynamic_weight_options[i]
287
+ with dynamic_weight_panel[i]:
288
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, preprocessor, method))
 
289
 
290
  with st.form(key=f'{prompt_id}'):
291
  # buttons = st.columns([1, 1, 1])
 
314
  with st.spinner('Loading images...'):
315
  self.gallery_standard(items, col_num, info)
316
 
 
 
 
317
  def submit_actions(self, status, prompt_id):
318
  if status == 'Select':
319
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
 
325
  print(st.session_state.selected_dict, 'deselect')
326
  st.experimental_rerun()
327
  # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
 
328
  elif status == 'Continue':
329
  st.session_state.selected_dict[prompt_id] = []
330
  for key in st.session_state:
 
334
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
335
  # switch_page("ranking")
336
  print(st.session_state.selected_dict, 'continue')
337
+ st.experimental_rerun()
338
 
339
+ def dynamic_weight(self, prompt_id, items, preprocessor='crop', method='Grid Search'):
340
+ selected = items[
341
+ items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
342
  optimal_weight = [0, 0, 0]
343
+
344
  if method == 'Grid Search':
345
  # grid search method
346
  top_ranking = len(items) * len(selected)
 
348
  for clip_weight in np.arange(-1, 1, 0.1):
349
  for mcos_weight in np.arange(-1, 1, 0.1):
350
  for pop_weight in np.arange(-1, 1, 0.1):
 
 
351
 
352
+ weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
353
+ weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
354
+ # print('weight_all_sorted:', weight_all_sorted)
355
  weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
356
 
357
  # get the index of values of weight_selected in weight_all_sorted
 
360
  rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
361
  if sum(rankings) <= top_ranking:
362
  top_ranking = sum(rankings)
363
+ print('current top ranking:', top_ranking, rankings)
364
  optimal_weight = [clip_weight, mcos_weight, pop_weight]
365
  print('optimal weight:', optimal_weight)
366
 
 
401
  optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
402
  print('optimal weight:', optimal_weight)
403
 
404
+ st.session_state.score_weights[0: 3] = optimal_weight
405
 
406
 
407