Ricercar commited on
Commit
6bdebc7
1 Parent(s): d471ad5

update bar chart

Browse files
Files changed (2) hide show
  1. app.py +46 -52
  2. test_altair.py +22 -47
app.py CHANGED
@@ -20,16 +20,38 @@ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_do
20
 
21
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
22
  @st.cache_resource
23
- def altair_histogram(hist_data, sort_by):
24
  brushed = alt.selection_interval(encodings=['x'], name="brushed")
25
- return (
 
26
  alt.Chart(hist_data)
27
- .mark_bar()
28
- .encode(alt.X(f"{sort_by}:Q", bin=True), y="count()")
29
- .add_selection(brushed)
30
- .properties(width=600, height=300)
 
 
 
 
 
 
 
 
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class GalleryApp:
34
  def __init__(self, promptBook, images_ds):
35
  self.promptBook = promptBook
@@ -169,7 +191,6 @@ class GalleryApp:
169
 
170
  return items, info, col_num
171
 
172
-
173
  def selection_panel_2(self, items):
174
  selecters = st.columns([1, 5])
175
 
@@ -226,14 +247,25 @@ class GalleryApp:
226
  items = items[items['checked'] == True].reset_index(drop=True)
227
  print(items)
228
 
 
229
  if sort_type == 'Scores':
230
- st.write('Select the range of scores to show')
231
- hist_data = pd.DataFrame(items[sort_by])
232
- event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
233
- r = event_dict.get(sort_by)
234
- if r:
235
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
236
- st.write(r)
 
 
 
 
 
 
 
 
 
 
237
 
238
  info = st.multiselect('Show Info',
239
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
@@ -308,7 +340,6 @@ class GalleryApp:
308
  except:
309
  pass
310
 
311
-
312
  # add safety check for some prompts
313
  safety_check = True
314
  unsafe_prompts = {}
@@ -398,44 +429,7 @@ if __name__ == '__main__':
398
  login(token=os.environ.get("HF_TOKEN"))
399
  st.set_page_config(layout="wide")
400
 
401
- # if 'roster' not in st.session_state:
402
- # print('loading roster')
403
- # # st.session_state.roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
404
- # st.session_state.roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
405
- # st.session_state.roster = st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
406
- # 'model_download_count']].drop_duplicates().reset_index(drop=True)
407
- # # add model download count from roster to promptbook dataframe
408
- # if 'promptBook' not in st.session_state:
409
- # print('loading promptBook')
410
- #
411
- # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
412
- # # add 'checked' column to promptBook if not exist
413
- # if 'checked' not in st.session_state.promptBook.columns:
414
- # st.session_state.promptBook.loc[:, 'checked'] = False
415
- #
416
- # # add 'custom_score_weights' column to promptBook if not exist
417
- # if 'weighted_score_sum' not in st.session_state.promptBook.columns:
418
- # st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
419
- #
420
- # st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
421
- # # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
422
- # print(st.session_state.images)
423
- # print('images loaded')
424
- # # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train'))
425
- # st.session_state.promptBook = st.session_state.promptBook.merge(st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']], on=['model_id', 'modelVersion_id'], how='left')
426
- #
427
- # # add column to record current row index
428
- # st.session_state.promptBook['row_idx'] = st.session_state.promptBook.index
429
- # print('promptBook loaded')
430
- # # print(st.session_state.promptBook)
431
- #
432
- # check_roster_error = False
433
- # if check_roster_error:
434
- # # print all rows with the same model_id and modelVersion_id but different model_download_count in roster
435
- # print(st.session_state.roster[st.session_state.roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)].sort_values(by=['model_id', 'modelVersion_id']))
436
  roster, promptBook, images_ds = load_hf_dataset()
437
- # if 'images' not in st.session_state:
438
- # st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
439
 
440
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
441
  app.app()
 
20
 
21
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
22
  @st.cache_resource
23
+ def altair_histogram(hist_data, sort_by, mini, maxi):
24
  brushed = alt.selection_interval(encodings=['x'], name="brushed")
25
+
26
+ chart = (
27
  alt.Chart(hist_data)
28
+ .mark_bar(opacity=0.7, cornerRadius=2)
29
+ .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
30
+ # .add_selection(brushed)
31
+ # .properties(width=800, height=300)
32
+ )
33
+
34
+ # Create a transparent rectangle for highlighting the range
35
+ highlight = (
36
+ alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
37
+ .mark_rect(opacity=0.3)
38
+ .encode(x='x1', x2='x2')
39
+ # .properties(width=800, height=300)
40
  )
41
 
42
+ # Layer the chart and the highlight rectangle
43
+ layered_chart = alt.layer(chart, highlight)
44
+
45
+ return layered_chart
46
+
47
+ # return (
48
+ # alt.Chart(hist_data)
49
+ # .mark_bar()
50
+ # .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
51
+ # .add_selection(brushed)
52
+ # .properties(width=600, height=300)
53
+ # )
54
+
55
  class GalleryApp:
56
  def __init__(self, promptBook, images_ds):
57
  self.promptBook = promptBook
 
191
 
192
  return items, info, col_num
193
 
 
194
  def selection_panel_2(self, items):
195
  selecters = st.columns([1, 5])
196
 
 
247
  items = items[items['checked'] == True].reset_index(drop=True)
248
  print(items)
249
 
250
+ # draw a distribution histogram
251
  if sort_type == 'Scores':
252
+ with st.expander('Show score distribution histogram and select score range'):
253
+ st.write('**Score distribution histogram**')
254
+ chart_space = st.container()
255
+ # st.write('Select the range of scores to show')
256
+ hist_data = pd.DataFrame(items[sort_by])
257
+ mini = hist_data[sort_by].min().item()
258
+ maxi = hist_data[sort_by].max().item()
259
+ st.write('**Select the range of scores to show**')
260
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), label_visibility='collapsed')
261
+ with chart_space:
262
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
263
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
264
+ # r = event_dict.get(sort_by)
265
+ if r:
266
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
267
+ # st.write(r)
268
+
269
 
270
  info = st.multiselect('Show Info',
271
  ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
 
340
  except:
341
  pass
342
 
 
343
  # add safety check for some prompts
344
  safety_check = True
345
  unsafe_prompts = {}
 
429
  login(token=os.environ.get("HF_TOKEN"))
430
  st.set_page_config(layout="wide")
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  roster, promptBook, images_ds = load_hf_dataset()
 
 
433
 
434
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
435
  app.app()
test_altair.py CHANGED
@@ -1,50 +1,25 @@
1
- import altair as alt
2
  import streamlit as st
 
3
  import pandas as pd
4
- import numpy as np
5
-
6
- from streamlit_vega_lite import vega_lite_component, altair_component, _component_func
7
-
8
- hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["abc"])
9
- print(hist_data)
10
-
11
- @st.cache_resource
12
- def altair_histogram():
13
- brushed = alt.selection_interval(encodings=["x"], name="brushed")
14
-
15
- return (
16
- alt.Chart(hist_data)
17
- .mark_bar()
18
- .encode(alt.X("abc:Q", bin=True), y="count()")
19
- .add_selection(brushed)
20
- )
21
-
22
- chart = altair_histogram()
23
- res = st.altair_chart(chart, use_container_width=True)
24
- # print(res)
25
- event_dict = altair_component(altair_chart=altair_histogram())
26
- chart_dict = chart.to_dict()
27
- print(chart_dict)
28
- altair_chart = chart.copy()
29
- datasets = {}
30
-
31
- def id_transform(data):
32
- """Altair data transformer that returns a fake named dataset with the
33
- object id."""
34
- name = f"d{id(data)}"
35
- datasets[name] = data
36
- return {"name": name}
37
-
38
- alt.data_transformers.register("id", id_transform)
39
-
40
- with alt.data_transformers.enable("id"):
41
- chart_dict = altair_chart.to_dict()
42
- # st.write(event_dict)
43
-
44
- event_dict = _component_func(spec=chart_dict, **datasets, key=None, default={})
45
- # print(chart_dict)
46
 
47
- r = event_dict.get("abc")
48
- if r:
49
- filtered = hist_data[(hist_data.abc >= r[0]) & (hist_data.abc < r[1])]
50
- st.write(filtered)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import altair as alt
3
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # Generate random data for the chart
6
+ data = pd.DataFrame({
7
+ 'Category': ['A', 'B', 'C', 'D', 'E'],
8
+ 'Value': [0.2, 0.5, 0.8, 1.2, 1.5]
9
+ })
10
+
11
+ # Define the color scale for the bars
12
+ color_scale = alt.Scale(
13
+ domain=[0, 1], # Values between 0 and 1 will be blue
14
+ range=['steelblue', 'lightgray']
15
+ )
16
+
17
+ # Create the bar chart using Altair
18
+ chart = alt.Chart(data).mark_bar().encode(
19
+ x='Category',
20
+ y='Value',
21
+ color=alt.Color('Value', scale=color_scale)
22
+ )
23
+
24
+ # Render the chart using Streamlit
25
+ st.altair_chart(chart, use_container_width=True)