Ricercar commited on
Commit
67b2625
β€’
1 Parent(s): 3e0a995

change summary visual

Browse files
Files changed (2) hide show
  1. Home.py +9 -9
  2. pages/Summary.py +69 -52
Home.py CHANGED
@@ -55,11 +55,11 @@ def info():
55
  with st.sidebar:
56
  st.write('## About')
57
  st.write(
58
- "**This is a web application for individual users to quickly dig out the most suitable text-to-image generation model from [civitai](https://civitai.com).** Our research aims to understand personal preference to images synthesized by generative models fine-tuned on stable diffusion and you can contribute by playing with this tool and giving us your feedback! "
59
  )
60
 
61
  st.write(
62
- "After picking images you liked from Gallery and a battle-mode Ranking Contest, a summary dashboard will be presented **indicating your preferred models with download links ready to be deployed in [Webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** !"
63
  )
64
 
65
 
@@ -88,15 +88,15 @@ if __name__ == '__main__':
88
  st.write('### About GEMRec')
89
  st.write("**GE**nerative **M**odel **Rec**ommendation (**GEMRec**) is a research project by [MAPS Lab](https://github.com/MAPS-research), NYU Shanghai.")
90
  st.write('### Our Task')
91
- st.write('Given a user’s preference on a set of generated images, we aim to recommend the most preferred generative model for the user.')
92
  st.write('### Our Approach')
93
- st.write('We propose a two-stage framework, which contains prompt-model retrival and generated item ranking. :red[Your participation in this web application will help us to improve our framework and to further our research on personalization.]')
94
- st.write('### Key Contributions')
95
- st.write('1. We propose a two-stage framework to approach the Generative Model Recommendation problem. Our framework allows end-users to effectively explore a diverse set of generative models to understand their expressiveness. It also allows system developers to elicit user preferences for items generated from personalized prompts.')
96
- st.write('2. We release GEMRec-18K, a dense prompt-model interaction dataset that consists of 18K images generated by pairing 200 generative models with 90 prompts collected from real-world usages, accompanied by detailed metadata and generation configurations. This dataset builds the cornerstone for exploring Generative Recommendation and can be useful for other tasks related to understanding generative models')
97
- st.write('3. We take the first step in examining evaluation metrics for personalized image generations and identify several limitations in existing metrics. We propose a weighted metric that is more suitable for the task and opens up directions for future improvements in model training and evaluations.')
98
 
99
- with st.expander(label='**πŸ’» Where can I find the paper and dataset?**'):
100
  st.write('### Paper')
101
  st.write('Arxiv: [Towards Personalized Prompt-Model Retrieval for Generative Recommendation](https://arxiv.org/abs/2308.02205)')
102
  st.write('### GEMRec-18K Dataset')
 
55
  with st.sidebar:
56
  st.write('## About')
57
  st.write(
58
+ "This is a web application **for individual users to quickly dig out the most preferable text-to-image models from [civitai](https://civitai.com) for different prompts**. Our research aims to understand personal preference towards generative models and you can contribute by playing with this tool and giving us your feedback! "
59
  )
60
 
61
  st.write(
62
+ "After picking images you liked from Gallery and a Ranking Contest, a summary dashboard will be presented **indicating your preferred models with download links ready to be deployed in [Webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** !"
63
  )
64
 
65
 
 
88
  st.write('### About GEMRec')
89
  st.write("**GE**nerative **M**odel **Rec**ommendation (**GEMRec**) is a research project by [MAPS Lab](https://github.com/MAPS-research), NYU Shanghai.")
90
  st.write('### Our Task')
91
+ st.write('Navigate hundreds of text-to-image models through various categories of pre-defined prompts and a graph-based interface. Given a user’s preference and interaction data, we aim to recommend the most preferred generative model for the user.')
92
  st.write('### Our Approach')
93
+ st.write('We propose a two-stage framework, which contains prompt-model retrieval and generative model ranking. :red[Your participation in this web application will help us to improve our framework and to further our research on personalization.]')
94
+ # st.write('### Key Contributions')
95
+ # st.write('1. We propose a two-stage framework to approach the Generative Model Recommendation problem. Our framework allows end-users to effectively explore a diverse set of generative models to understand their expressiveness. It also allows system developers to elicit user preferences for items generated from personalized prompts.')
96
+ # st.write('2. We release GEMRec-18K, a dense prompt-model interaction dataset that consists of 18K images generated by pairing 200 generative models with 90 prompts collected from real-world usages, accompanied by detailed metadata and generation configurations. This dataset builds the cornerstone for exploring Generative Recommendation and can be useful for other tasks related to understanding generative models')
97
+ # st.write('3. We take the first step in examining evaluation metrics for personalized image generations and identify several limitations in existing metrics. We propose a weighted metric that is more suitable for the task and opens up directions for future improvements in model training and evaluations.')
98
 
99
+ with st.expander(label='**πŸ“‘ Where can I find the paper and dataset?**'):
100
  st.write('### Paper')
101
  st.write('Arxiv: [Towards Personalized Prompt-Model Retrieval for Generative Recommendation](https://arxiv.org/abs/2308.02205)')
102
  st.write('### GEMRec-18K Dataset')
pages/Summary.py CHANGED
@@ -30,7 +30,7 @@ class DashboardApp:
30
 
31
  def sidebar(self, tags, mode):
32
  with st.sidebar:
33
- tag = st.selectbox('Select a tag', tags, key='tag')
34
  # st.write('---')
35
  with st.form('summary_sidebar_form'):
36
  st.write('## Want a more comprehensive summary?')
@@ -48,10 +48,10 @@ class DashboardApp:
48
  # if submit_feedback:
49
  # print(feedback)
50
 
51
- return tag
52
 
53
  def leaderboard(self, tag, db_table):
54
- tag = '%' if tag == 'all' else tag
55
 
56
  # get the ranking results of the current user
57
  curser = RANKING_CONN.cursor()
@@ -65,39 +65,40 @@ class DashboardApp:
65
  # sort the modelVersion_standings by value into a list of tuples in descending order
66
  st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
67
 
68
- tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
69
 
70
- with tab1:
71
  # self.podium(modelVersion_standings)
72
- switch_stage = st.toggle('Manual Reorder', key='switch_stage')
73
-
74
- example_prompts = []
75
- # get example images
76
- for key, value in st.session_state.selected_dict.items():
77
- for model in st.session_state.modelVersion_standings[tag]:
78
- if model[0] in value:
79
- example_prompts.append(key)
80
-
81
- if switch_stage:
82
- self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit', example_prompts=example_prompts)
83
- else:
84
- self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
85
- # if st.session_state.summary_mode == 'display':
86
- # switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
87
- # self.podium_expander(tag, n=3, summary_mode='display')
88
- #
89
- # elif st.session_state.summary_mode == 'edit':
90
- # switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
91
- # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
92
-
93
- with tab2:
94
- st.write('**Detailed information of all selected models**')
95
- detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
96
-
97
- detailed_info = detailed_info[['model_name', 'modelVersion_name', 'model_download_count', 'tag', 'baseModel']]
98
-
99
- st.data_editor(detailed_info, hide_index=False, disabled=True)
100
- st.caption('You can click the header to sort the table by that column.')
 
101
 
102
  def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
103
 
@@ -110,29 +111,40 @@ class DashboardApp:
110
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰' if i == 2 else '🎈'
111
  podium_display = st.columns([1, 14], gap='medium')
112
  with podium_display[0]:
113
- if summary_mode == 'display':
114
- st.title(f'{icon}')
115
- elif summary_mode == 'edit':
116
- settop = st.button('πŸ”', key=f'settop_{modelVersion_id}', help='Set this model to the top', disabled=i == 0, on_click=self.switch_order, args=(tag, i, 0), use_container_width=True)
117
- moveup = st.button('⬆', key=f'moveup_{modelVersion_id}', help='Move this model up', disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1), use_container_width=True)
118
- movedown = st.button('⬇', key=f'movedown_{modelVersion_id}', help='Move this model down', disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1), use_container_width=True)
 
119
  with podium_display[1]:
120
- title_display = st.columns([3, 1, 1, 1])
121
  with title_display[0]:
122
  st.write(f'##### {model_name}, {modelVersion_name}')
123
  # st.write(f'Ranking Score: {winning_times}')
124
  with title_display[1]:
125
  # image_display = st.selectbox('image display', ['Featured', 'All Images'], key=f'image_display_{modelVersion_id}', label_visibility='collapsed')
126
- image_display = st.checkbox('Show all images', key=f'image_display_{modelVersion_id}')
127
 
128
  with title_display[2]:
129
- st.link_button('Download Model', url, use_container_width=True)
130
  with title_display[3]:
131
- st.link_button('Civitai Page', f'https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}', use_container_width=True, type='primary')
132
  # st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
133
  # with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
134
-
135
-
 
 
 
 
 
 
 
 
 
 
136
 
137
  if not image_display:
138
  example_images = self.promptBook[self.promptBook['prompt_id'].isin(example_prompts) & (self.promptBook['modelVersion_id']==modelVersion_id)]['image_id'].values
@@ -143,11 +155,10 @@ class DashboardApp:
143
  )
144
 
145
  else:
146
- st.toast('🐌 It may take a while to load all images. Please be patient.')
147
  # with st.expander(f'Show Images'):
148
  images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
149
 
150
- safety_check = st.checkbox('Include potentially unsafe or offensive images', value=False, key=modelVersion_id)
151
  unsafe_prompts = json.load(open('data/unsafe_prompts.json', 'r'))
152
  # merge dict values into one list
153
  unsafe_prompts = [item for sublist in unsafe_prompts.values() for item in sublist]
@@ -162,6 +173,7 @@ class DashboardApp:
162
  images,
163
  img_style={"margin": "5px", "height": "100px"}
164
  )
 
165
 
166
  # # st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
167
  # col_num = 4
@@ -212,7 +224,7 @@ class DashboardApp:
212
  # get tags from database of the current user
213
  db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
214
 
215
- tags = ['all']
216
  curser = RANKING_CONN.cursor()
217
  curser.execute(
218
  f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
@@ -220,11 +232,13 @@ class DashboardApp:
220
  tags.append(row['tag'])
221
  curser.close()
222
 
223
- if tags == ['all']:
224
  st.info(f'No rankings are finished with {mode} mode yet.')
225
 
226
  else:
227
- tag = self.sidebar(tags, mode)
 
 
228
  self.leaderboard(tag, db_table)
229
 
230
  with st.sidebar:
@@ -240,7 +254,7 @@ class DashboardApp:
240
  RANKING_CONN.commit()
241
  curser.close()
242
 
243
- st.toast('Thanks for your feedback! We will take it into consideration in our future work.')
244
 
245
 
246
  if __name__ == "__main__":
@@ -280,4 +294,7 @@ if __name__ == "__main__":
280
  app = DashboardApp(roster, promptBook, session_finished)
281
  app.app()
282
 
 
 
 
283
 
 
30
 
31
  def sidebar(self, tags, mode):
32
  with st.sidebar:
33
+ # tag = st.selectbox('Select a tag', tags, key='tag')
34
  # st.write('---')
35
  with st.form('summary_sidebar_form'):
36
  st.write('## Want a more comprehensive summary?')
 
48
  # if submit_feedback:
49
  # print(feedback)
50
 
51
+ # return tag
52
 
53
  def leaderboard(self, tag, db_table):
54
+ tag = '%' if tag == 'overview' else tag
55
 
56
  # get the ranking results of the current user
57
  curser = RANKING_CONN.cursor()
 
65
  # sort the modelVersion_standings by value into a list of tuples in descending order
66
  st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
67
 
68
+ # tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
69
 
70
+ # with tab1:
71
  # self.podium(modelVersion_standings)
72
+ # switch_stage = st.toggle('Manual Reorder', key='switch_stage')
73
+
74
+ example_prompts = []
75
+ # get example images
76
+ for key, value in st.session_state.selected_dict.items():
77
+ for model in st.session_state.modelVersion_standings[tag]:
78
+ if model[0] in value:
79
+ example_prompts.append(key)
80
+
81
+ # if switch_stage:
82
+ # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit', example_prompts=example_prompts)
83
+ # else:
84
+ self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
85
+ # if st.session_state.summary_mode == 'display':
86
+ # switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
87
+ # self.podium_expander(tag, n=3, summary_mode='display')
88
+ #
89
+ # elif st.session_state.summary_mode == 'edit':
90
+ # switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
91
+ # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
92
+
93
+ # with tab2:
94
+ st.write('---')
95
+ st.write('**Detailed information of all selected models**')
96
+ detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
97
+
98
+ detailed_info = detailed_info[['model_name', 'modelVersion_name', 'model_download_count', 'tag', 'baseModel']]
99
+
100
+ st.data_editor(detailed_info, hide_index=False, disabled=True)
101
+ st.caption('You can click the header to sort the table by that column.')
102
 
103
  def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
104
 
 
111
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰' if i == 2 else '🎈'
112
  podium_display = st.columns([1, 14], gap='medium')
113
  with podium_display[0]:
114
+ st.title(f'{icon}')
115
+ # if summary_mode == 'display':
116
+ # st.title(f'{icon}')
117
+ # elif summary_mode == 'edit':
118
+ # settop = st.button('πŸ”', key=f'settop_{modelVersion_id}', help='Set this model to the top', disabled=i == 0, on_click=self.switch_order, args=(tag, i, 0), use_container_width=True)
119
+ # moveup = st.button('⬆', key=f'moveup_{modelVersion_id}', help='Move this model up', disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1), use_container_width=True)
120
+ # movedown = st.button('⬇', key=f'movedown_{modelVersion_id}', help='Move this model down', disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1), use_container_width=True)
121
  with podium_display[1]:
122
+ title_display = st.columns([3.5, 2, 2, 2, 0.5, 0.5, 0.5])
123
  with title_display[0]:
124
  st.write(f'##### {model_name}, {modelVersion_name}')
125
  # st.write(f'Ranking Score: {winning_times}')
126
  with title_display[1]:
127
  # image_display = st.selectbox('image display', ['Featured', 'All Images'], key=f'image_display_{modelVersion_id}', label_visibility='collapsed')
128
+ image_display = st.toggle('Show all images', key=f'image_display_{modelVersion_id}')
129
 
130
  with title_display[2]:
131
+ st.link_button('Download', url, use_container_width=True)
132
  with title_display[3]:
133
+ st.link_button('Civitai', f'https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}', use_container_width=True, type='primary')
134
  # st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
135
  # with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
136
+ with title_display[4]:
137
+ settop = st.button('πŸ”', key=f'settop_{modelVersion_id}', help='Set this model to the top',
138
+ disabled=i == 0, on_click=self.switch_order, args=(tag, i, 0),
139
+ use_container_width=True)
140
+ with title_display[5]:
141
+ moveup = st.button('⬆', key=f'moveup_{modelVersion_id}', help='Move this model up',
142
+ disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1),
143
+ use_container_width=True)
144
+ with title_display[6]:
145
+ movedown = st.button('⬇', key=f'movedown_{modelVersion_id}', help='Move this model down',
146
+ disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1),
147
+ use_container_width=True)
148
 
149
  if not image_display:
150
  example_images = self.promptBook[self.promptBook['prompt_id'].isin(example_prompts) & (self.promptBook['modelVersion_id']==modelVersion_id)]['image_id'].values
 
155
  )
156
 
157
  else:
 
158
  # with st.expander(f'Show Images'):
159
  images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
160
 
161
+ safety_check = st.toggle('Include potentially unsafe or offensive images', value=False, key=modelVersion_id)
162
  unsafe_prompts = json.load(open('data/unsafe_prompts.json', 'r'))
163
  # merge dict values into one list
164
  unsafe_prompts = [item for sublist in unsafe_prompts.values() for item in sublist]
 
173
  images,
174
  img_style={"margin": "5px", "height": "100px"}
175
  )
176
+ st.write('🐌 It may take a while to load all images. Please be patient, and **NEVER USE THE REFRESH BUTTON ON YOUR BROWSER**.')
177
 
178
  # # st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
179
  # col_num = 4
 
224
  # get tags from database of the current user
225
  db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
226
 
227
+ tags = ['overview']
228
  curser = RANKING_CONN.cursor()
229
  curser.execute(
230
  f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
 
232
  tags.append(row['tag'])
233
  curser.close()
234
 
235
+ if tags == ['overview']:
236
  st.info(f'No rankings are finished with {mode} mode yet.')
237
 
238
  else:
239
+ tags = tags[0:1] if len(tags) == 2 else tags
240
+ tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
241
+ self.sidebar(tags, mode)
242
  self.leaderboard(tag, db_table)
243
 
244
  with st.sidebar:
 
254
  RANKING_CONN.commit()
255
  curser.close()
256
 
257
+ st.toast('πŸ™ **Thanks for your feedback! We will take it into consideration in our future work.**')
258
 
259
 
260
  if __name__ == "__main__":
 
294
  app = DashboardApp(roster, promptBook, session_finished)
295
  app.app()
296
 
297
+ with open('./css/style.css') as f:
298
+ st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
299
+
300