Ricercar committed on
Commit
319290c
1 Parent(s): 5bfac4b

new data cache method!

Browse files
Files changed (2) hide show
  1. app.py +103 -45
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import random
4
  import pandas as pd
5
  import glob
6
- import csv
7
  from PIL import Image
8
  import datasets
9
  from datasets import load_dataset, Dataset, load_from_disk
@@ -13,13 +12,28 @@ import requests
13
  from bs4 import BeautifulSoup
14
  import re
15
 
 
 
 
16
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class GalleryApp:
20
- def __init__(self, promptBook):
21
  self.promptBook = promptBook
22
- st.set_page_config(layout="wide")
23
 
24
  def gallery_masonry(self, items, col_num, info):
25
  cols = st.columns(col_num)
@@ -27,7 +41,7 @@ class GalleryApp:
27
  # items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
28
  for idx in range(len(items)):
29
  with cols[idx % col_num]:
30
- image = st.session_state.images[items.iloc[idx]['row_idx'].item()]['image']
31
  st.image(image,
32
  use_column_width=True,
33
  )
@@ -58,7 +72,7 @@ class GalleryApp:
58
  if idx + j < len(items):
59
  with cols[j]:
60
  # show image
61
- image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
62
 
63
  st.image(image,
64
  use_column_width=True,
@@ -184,11 +198,12 @@ class GalleryApp:
184
  with sub_selecters[2]:
185
  pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
186
 
187
- items.loc[:, 'weighted_score_sum'] = items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
188
- 'norm_pop'] * pop_weight
189
 
190
  continue_idx = 3
191
 
 
192
  with sub_selecters[continue_idx]:
193
  order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
194
  if order == 'Ascending':
@@ -211,6 +226,15 @@ class GalleryApp:
211
  items = items[items['checked'] == True].reset_index(drop=True)
212
  print(items)
213
 
 
 
 
 
 
 
 
 
 
214
  info = st.multiselect('Show Info',
215
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
216
  'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
@@ -303,6 +327,7 @@ class GalleryApp:
303
 
304
  if safety_check:
305
  items, info, col_num = self.selection_panel_2(items)
 
306
  # self.gallery_standard(items, col_num, info)
307
 
308
  with st.form(key=f'{prompt_id}', clear_on_submit=False):
@@ -340,44 +365,77 @@ class GalleryApp:
340
  dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
341
 
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  if __name__ == '__main__':
344
  login(token=os.environ.get("HF_TOKEN"))
345
-
346
- if 'roster' not in st.session_state:
347
- print('loading roster')
348
- # st.session_state.roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
349
- st.session_state.roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
350
- st.session_state.roster = st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
351
- 'model_download_count']].drop_duplicates().reset_index(drop=True)
352
- # add model download count from roster to promptbook dataframe
353
- if 'promptBook' not in st.session_state:
354
- print('loading promptBook')
355
-
356
- st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
357
- # add 'checked' column to promptBook if not exist
358
- if 'checked' not in st.session_state.promptBook.columns:
359
- st.session_state.promptBook.loc[:, 'checked'] = False
360
-
361
- # add 'custom_score_weights' column to promptBook if not exist
362
- if 'weighted_score_sum' not in st.session_state.promptBook.columns:
363
- st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
364
-
365
- st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
366
- # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
367
- print(st.session_state.images)
368
- print('images loaded')
369
- # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train'))
370
- st.session_state.promptBook = st.session_state.promptBook.merge(st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']], on=['model_id', 'modelVersion_id'], how='left')
371
-
372
- # add column to record current row index
373
- st.session_state.promptBook['row_idx'] = st.session_state.promptBook.index
374
- print('promptBook loaded')
375
- # print(st.session_state.promptBook)
376
-
377
- check_roster_error = False
378
- if check_roster_error:
379
- # print all rows with the same model_id and modelVersion_id but different model_download_count in roster
380
- print(st.session_state.roster[st.session_state.roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)].sort_values(by=['model_id', 'modelVersion_id']))
381
-
382
- app = GalleryApp(promptBook=st.session_state.promptBook)
 
 
 
 
383
  app.app()
 
3
  import random
4
  import pandas as pd
5
  import glob
 
6
  from PIL import Image
7
  import datasets
8
  from datasets import load_dataset, Dataset, load_from_disk
 
12
  from bs4 import BeautifulSoup
13
  import re
14
 
15
+ import altair as alt
16
+ from streamlit_vega_lite import vega_lite_component, altair_component, _component_func
17
+
18
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
19
 
20
 
21
+ # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
22
+ @st.cache_resource
23
+ def altair_histogram(hist_data, sort_by):
24
+ brushed = alt.selection_interval(encodings=['x'], name="brushed")
25
+ return (
26
+ alt.Chart(hist_data)
27
+ .mark_bar()
28
+ .encode(alt.X(f"{sort_by}:Q", bin=True), y="count()")
29
+ .add_selection(brushed)
30
+ .properties(width=600, height=300)
31
+ )
32
+
33
  class GalleryApp:
34
+ def __init__(self, promptBook, images_ds):
35
  self.promptBook = promptBook
36
+ self.images_ds = images_ds
37
 
38
  def gallery_masonry(self, items, col_num, info):
39
  cols = st.columns(col_num)
 
41
  # items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
42
  for idx in range(len(items)):
43
  with cols[idx % col_num]:
44
+ image = self.images_ds[items.iloc[idx]['row_idx'].item()]['image']
45
  st.image(image,
46
  use_column_width=True,
47
  )
 
72
  if idx + j < len(items):
73
  with cols[j]:
74
  # show image
75
+ image = self.images_ds[items.iloc[idx+j]['row_idx'].item()]['image']
76
 
77
  st.image(image,
78
  use_column_width=True,
 
198
  with sub_selecters[2]:
199
  pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
200
 
201
+ items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
202
+ 'norm_pop'] * pop_weight, 4)
203
 
204
  continue_idx = 3
205
 
206
+
207
  with sub_selecters[continue_idx]:
208
  order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
209
  if order == 'Ascending':
 
226
  items = items[items['checked'] == True].reset_index(drop=True)
227
  print(items)
228
 
229
+ if sort_type == 'Scores':
230
+ st.write('Select the range of scores to show')
231
+ hist_data = pd.DataFrame(items[sort_by])
232
+ event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
233
+ r = event_dict.get(sort_by)
234
+ if r:
235
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
236
+ st.write(r)
237
+
238
  info = st.multiselect('Show Info',
239
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
240
  'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
 
327
 
328
  if safety_check:
329
  items, info, col_num = self.selection_panel_2(items)
330
+
331
  # self.gallery_standard(items, col_num, info)
332
 
333
  with st.form(key=f'{prompt_id}', clear_on_submit=False):
 
365
  dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
366
 
367
 
368
+ @st.cache_data
369
+ def load_hf_dataset():
370
+ # load from huggingface
371
+ roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
372
+ promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
373
+ images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
374
+
375
+ # process dataset
376
+ roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
377
+ 'model_download_count']].drop_duplicates().reset_index(drop=True)
378
+
379
+ # add 'checked' column to promptBook if not exist
380
+ if 'checked' not in promptBook.columns:
381
+ promptBook.loc[:, 'checked'] = False
382
+
383
+ # add 'custom_score_weights' column to promptBook if not exist
384
+ if 'weighted_score_sum' not in promptBook.columns:
385
+ promptBook.loc[:, 'weighted_score_sum'] = 0
386
+
387
+ # merge roster and promptbook
388
+ promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
389
+ on=['model_id', 'modelVersion_id'], how='left')
390
+
391
+ # add column to record current row index
392
+ promptBook.loc[:, 'row_idx'] = promptBook.index
393
+
394
+ return roster, promptBook, images_ds
395
+
396
+
397
  if __name__ == '__main__':
398
  login(token=os.environ.get("HF_TOKEN"))
399
+ st.set_page_config(layout="wide")
400
+
401
+ # if 'roster' not in st.session_state:
402
+ # print('loading roster')
403
+ # # st.session_state.roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
404
+ # st.session_state.roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
405
+ # st.session_state.roster = st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
406
+ # 'model_download_count']].drop_duplicates().reset_index(drop=True)
407
+ # # add model download count from roster to promptbook dataframe
408
+ # if 'promptBook' not in st.session_state:
409
+ # print('loading promptBook')
410
+ #
411
+ # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
412
+ # # add 'checked' column to promptBook if not exist
413
+ # if 'checked' not in st.session_state.promptBook.columns:
414
+ # st.session_state.promptBook.loc[:, 'checked'] = False
415
+ #
416
+ # # add 'custom_score_weights' column to promptBook if not exist
417
+ # if 'weighted_score_sum' not in st.session_state.promptBook.columns:
418
+ # st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
419
+ #
420
+ # st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
421
+ # # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
422
+ # print(st.session_state.images)
423
+ # print('images loaded')
424
+ # # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train'))
425
+ # st.session_state.promptBook = st.session_state.promptBook.merge(st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']], on=['model_id', 'modelVersion_id'], how='left')
426
+ #
427
+ # # add column to record current row index
428
+ # st.session_state.promptBook['row_idx'] = st.session_state.promptBook.index
429
+ # print('promptBook loaded')
430
+ # # print(st.session_state.promptBook)
431
+ #
432
+ # check_roster_error = False
433
+ # if check_roster_error:
434
+ # # print all rows with the same model_id and modelVersion_id but different model_download_count in roster
435
+ # print(st.session_state.roster[st.session_state.roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)].sort_values(by=['model_id', 'modelVersion_id']))
436
+ roster, promptBook, images_ds = load_hf_dataset()
437
+ # if 'images' not in st.session_state:
438
+ # st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
439
+
440
+ app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
441
  app.app()
requirements.txt CHANGED
@@ -2,4 +2,5 @@ huggingface_hub
2
  streamlit-elements==0.1.0
3
  streamlit-extras
4
  altair<5
5
- streamlit-plotly-events
 
 
2
  streamlit-elements==0.1.0
3
  streamlit-extras
4
  altair<5
5
+ streamlit-plotly-events
6
+ streamlit-vega-lite