import os

import streamlit as st
import torch
import pandas as pd
import numpy as np
from datasets import load_dataset, Dataset, load_from_disk
from huggingface_hub import login
from streamlit_agraph import agraph, Node, Edge, Config
from sklearn.manifold import TSNE

# Roster columns kept for display and merged into the prompt book.
ROSTER_COLUMNS = ['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']
# Directory holding the per-prompt feature files ``<prompt_id>.pt``
# (relative to the current working directory).
FEATS_DIR = '../data/feats'
# Multiplier applied to the t-SNE coordinates before they are used as
# initial node positions in the graph.
TSNE_SCALE = 50


@st.cache_data
def load_hf_dataset():
    """Download and join the GEMRec roster and prompt-book datasets.

    Returns:
        ``(roster, promptBook)``:

        * ``roster`` holds the deduplicated ``ROSTER_COLUMNS`` rows of
          ``MAPS-research/GEMRec-Roster``.
        * ``promptBook`` is ``MAPS-research/GEMRec-Metadata`` left-merged with
          those roster columns on ``model_id``/``modelVersion_id``. It also
          gets a ``weighted_score_sum`` column (filled with 0 when missing)
          and a ``row_idx`` column holding each row's positional index.
    """
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered inside a Streamlit app, so only log in when a token is set.
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    # load from huggingface
    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))

    roster = roster[ROSTER_COLUMNS].drop_duplicates().reset_index(drop=True)

    # add 'weighted_score_sum' column to promptBook if it does not exist
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # attach model / model-version metadata to every prompt-book row
    promptBook = promptBook.merge(roster[ROSTER_COLUMNS], on=['model_id', 'modelVersion_id'], how='left')

    # record each row's current positional index
    promptBook.loc[:, 'row_idx'] = promptBook.index
    return roster, promptBook


@st.cache_data
def calc_tsne(prompt_id):
    """Project one prompt's feature vectors to 2D with t-SNE.

    Loads ``<FEATS_DIR>/<prompt_id>.pt`` and fits t-SNE on its ``'all'`` entry.

    Args:
        prompt_id: Prompt identifier as a string. It must match the stem of a
            ``.pt`` file in ``FEATS_DIR``.

    Returns:
        DataFrame with one row per row of ``'all'`` and columns ``x``, ``y``
        (t-SNE coordinates), ``prompt_id`` and ``modelVersion_id``.

    Raises:
        FileNotFoundError: if there is no feature file for ``prompt_id``.
    """
    print('==> loading feats')
    # Only the requested prompt is used, so load just its file rather than
    # every feature file in the directory.
    feats = torch.load(os.path.join(FEATS_DIR, f'{prompt_id}.pt'))

    print('==> applying t-SNE')
    tsne = TSNE(n_components=2, random_state=0)
    coords = tsne.fit_transform(feats['all'].numpy())

    feats_df = pd.DataFrame(coords, columns=['x', 'y'])
    feats_df['prompt_id'] = prompt_id
    # Keys other than 'all' are tensor-valued model-version ids. This assumes
    # their insertion order matches the row order of feats['all']
    # (TODO confirm against the feature-extraction script).
    feats_df['modelVersion_id'] = [int(k.item()) for k in feats if k != 'all' and k != 'tsne']
    return feats_df


def _image_ids_for_prompt(promptBook, prompt_id):
    """Return a Series mapping modelVersion_id -> image_id for ``prompt_id``.

    When several rows share a modelVersion_id, the first one (in promptBook
    order) wins.
    """
    rows = promptBook[promptBook['prompt_id'] == int(prompt_id)]
    rows = rows.drop_duplicates('modelVersion_id', keep='first')
    return rows.set_index('modelVersion_id')['image_id']


def main():
    """Render the t-SNE image graph for the prompt chosen in the sidebar."""
    st.set_page_config(layout="wide")

    _roster, promptBook = load_hf_dataset()

    with st.sidebar:
        st.write('## Select Prompt')
        prompts = sorted(promptBook['prompt_id'].unique().tolist())
        prompt_id = st.selectbox('Select Prompt', prompts, index=0)
        physics = st.checkbox('Enable Physics')

    feats_df = calc_tsne(str(prompt_id))
    # Build the lookup once instead of filtering the whole prompt book per point.
    image_ids = _image_ids_for_prompt(promptBook, prompt_id)

    nodes = []
    for row in feats_df.itertuples(index=False):
        image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_ids[row.modelVersion_id]}.png"
        nodes.append(
            Node(
                id=image_url,
                size=20,
                shape="image",
                image=image_url,
                # plain Python floats: vis-network expects numeric x/y, and
                # numpy scalars are not JSON-serializable
                x=float(row.x) * TSNE_SCALE,
                y=float(row.y) * TSNE_SCALE,
                fixed=not physics,
                color={'background': '#000000', 'border': '#ffffff'},
                shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
            )
        )
    edges = []

    config = Config(
        width='100%',
        height=800,
        directed=True,
        physics=physics,
        hierarchical=False,
    )

    cols = st.columns([3, 1], gap='large')
    with cols[0]:
        # agraph returns the id of the clicked node, i.e. its image URL
        return_value = agraph(nodes=nodes, edges=edges, config=config)
    with cols[1]:
        if return_value:
            st.image(return_value, use_column_width=True)
        else:
            st.write('No image selected')


if __name__ == '__main__':
    main()