"""Streamlit gallery of images generated by the models in the Model Coffer."""
import csv
import glob
import os
import random

import numpy as np
import pandas as pd
import streamlit as st
from datasets import Dataset, load_dataset, load_from_disk
from huggingface_hub import login
from PIL import Image

# NOTE(review): csv, glob, random, Dataset and Image are not used in this file.

# Roster columns joined onto the prompt book (model_id / modelVersion_id are the join keys).
ROSTER_COLUMNS = ['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
                  'model_download_count']
# Columns the user may choose to show under each image.
INFO_OPTIONS = ['brisque_score', 'clip_score', 'model_download_count', 'model_name', 'model_id',
                'modelVersion_name', 'modelVersion_id']
# Sort keys whose order selector starts on 'Descending' (presumably larger is better).
DESCENDING_BY_DEFAULT = ('clip_score', 'model_download_count')
# Number of leading ', '-separated prompt segments hidden in the prompt selector
# (presumably a prefix shared by all prompts -- confirm against the dataset).
PROMPT_PREFIX_SEGMENTS = 4


def _strip_prompt_prefix(prompt):
    """Return ``prompt`` without its first PROMPT_PREFIX_SEGMENTS ', '-separated segments.

    Used only as the display label of the prompt selector. Falls back to the full
    prompt when nothing would remain, so no option is shown with an empty label.
    """
    text = str(prompt)
    stripped = ', '.join(text.split(', ')[PROMPT_PREFIX_SEGMENTS:])
    return stripped or text


class GalleryApp:
    """Streamlit page that browses generated images by tag and prompt.

    Images are looked up in ``st.session_state.images`` (indexable, items with an
    ``'image'`` field) via the ``row_idx`` column of ``promptBook``; the
    ``__main__`` block of this file sets both up.
    """

    def __init__(self, promptBook):
        """Store the prompt-book dataframe and switch the page to the wide layout.

        Args:
            promptBook: DataFrame with one row per generated image. Columns read by
                this class include ``tag``, ``prompt``, ``prompt_id``,
                ``negativePrompt``, ``sampler``, ``cfgScale``, ``size``, ``seed``
                and ``row_idx``.
        """
        self.promptBook = promptBook
        # Streamlit documents set_page_config as needing to come before other
        # page-rendering commands, so it is issued as soon as the app is built.
        st.set_page_config(layout="wide")

    def gallery(self, items, col_num, info):
        """Render ``items`` as an image grid with ``col_num`` columns.

        Images are dealt round-robin across the columns; under each image one
        ``**key**: value`` line is written for every column name in ``info``.

        Args:
            items: DataFrame of rows to show; its ``row_idx`` column indexes
                ``st.session_state.images``.
            col_num: Number of grid columns (the caller's slider enforces >= 1).
            info: Column names of ``items`` to print under each image.
        """
        cols = st.columns(col_num)
        for position, (_, row) in enumerate(items.iterrows()):
            with cols[position % col_num]:
                # int() gives a plain Python int for the dataset lookup whether
                # pandas hands back a NumPy scalar or a Python int.
                image = st.session_state.images[int(row['row_idx'])]['image']
                # NOTE(review): newer Streamlit releases deprecate use_column_width
                # in favour of use_container_width -- check the pinned version.
                st.image(image, use_column_width=True)
                for key in info:
                    st.write(f"**{key}**: {row[key]}")

    def app(self):
        """Render the page: prompt metadata on the left, sortable gallery on the right."""
        st.title('Model Coffer Gallery')
        st.write('This is a gallery of images generated by the models in the Model Coffer')

        metadata, images = st.columns([1, 3])
        with metadata:
            items = self._render_prompt_metadata()
        with images:
            self._render_image_grid(items)

    def _render_prompt_metadata(self):
        """Let the user pick a tag and a prompt, show the prompt's metadata, return its rows.

        Returns:
            The rows of ``promptBook`` that share the selected prompt's ``prompt_id``,
            re-indexed from 0.
        """
        # NaN tags cannot be sorted against strings and never match ``== tag``.
        prompt_tags = np.sort(self.promptBook['tag'].dropna().unique())
        tag = st.selectbox('Select a tag', prompt_tags)
        items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)

        # The options are the full prompts; only their labels are shortened. Rebuilding
        # the full prompt from a shortened label failed for prompts with at most
        # PROMPT_PREFIX_SEGMENTS segments and picked the wrong prompt whenever two
        # prompts shared the same shortened label.
        prompt_full = st.selectbox('Select prompt', items['prompt'].unique(),
                                   format_func=_strip_prompt_prefix)
        prompt_id = items.loc[items['prompt'] == prompt_full, 'prompt_id'].iloc[0]
        items = items[items['prompt_id'] == prompt_id].reset_index(drop=True)

        st.write('**Prompt ID**')
        st.caption(f"{prompt_id}")
        for label, column in (('Prompt', 'prompt'), ('Negative Prompt', 'negativePrompt'),
                              ('Sampler', 'sampler'), ('cfgScale', 'cfgScale')):
            st.write(f"**{label}**")
            st.caption(f"{items[column].iat[0]}")
        # ``size`` is stored as "<width>x<height>".
        width, _, height = str(items['size'].iat[0]).partition('x')
        st.write('**Size**')
        st.caption(f"width: {width}, height: {height}")
        st.write('**Seed**')
        st.caption(f"{items['seed'].iat[0]}")
        return items

    def _render_image_grid(self, items):
        """Render the sort / order / info controls for ``items``, then the image grid."""
        sort_column, order_column, info_column = st.columns([1, 1, 2])
        with sort_column:
            # NOTE(review): sortable columns are chosen positionally -- from the 12th
            # column up to, but excluding, the last one (``row_idx``, which the loader
            # keeps last). This depends on the dataset schema; confirm when it changes.
            sort_by = st.selectbox('Sort by', items.columns[11: -1])
        with order_column:
            order = st.selectbox('Order', ['Ascending', 'Descending'],
                                 index=1 if sort_by in DESCENDING_BY_DEFAULT else 0)
        items = items.sort_values(by=[sort_by], ascending=order == 'Ascending').reset_index(drop=True)
        with info_column:
            # multiselect rejects a default that is not one of its options, and the
            # sort keys above come from the schema rather than from INFO_OPTIONS.
            default_info = [sort_by] if sort_by in INFO_OPTIONS else []
            info = st.multiselect('Show Info', INFO_OPTIONS, default=default_info)
        col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1,
                            key='col_num')

        self.gallery(items, col_num, info)


if __name__ == '__main__':
    # Streamlit re-executes this script on every interaction, so expensive setup is
    # cached in st.session_state and only done once per session.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token and 'hf_logged_in' not in st.session_state:
        # login(token=None) falls back to an interactive prompt that a Streamlit
        # server cannot answer; without HF_TOKEN the Hub libraries use whatever
        # token is already saved locally, if any.
        login(token=hf_token)
        st.session_state.hf_logged_in = True

    if 'roster' not in st.session_state:
        print('loading roster')
        # NOTE: previously loaded from the Hub dataset 'NYUSHPRP/ModelCofferRoster'.
        roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
        st.session_state.roster = roster[ROSTER_COLUMNS].drop_duplicates().reset_index(drop=True)

    if 'promptBook' not in st.session_state:
        print('loading promptBook')
        prompt_book = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
        st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
        print('images loaded')
        # row_idx is each row's position in the metadata dataset; GalleryApp.gallery
        # uses it to fetch the matching image from st.session_state.images (assumes
        # both datasets list rows in the same order -- TODO confirm). It is recorded
        # before the merge because roster rows sharing a (model_id, modelVersion_id)
        # pair make the left merge duplicate metadata rows, which would shift any
        # positional index taken afterwards.
        prompt_book['row_idx'] = np.arange(len(prompt_book))
        # Add model name, version name and download count from the roster.
        prompt_book = prompt_book.merge(st.session_state.roster[ROSTER_COLUMNS],
                                        on=['model_id', 'modelVersion_id'], how='left')
        # Keep row_idx as the last column: GalleryApp offers columns[11:-1] as sort keys.
        prompt_book['row_idx'] = prompt_book.pop('row_idx')
        # Publish only the finished frame so a failure above is retried on the next rerun
        # instead of leaving a half-built prompt book cached in the session.
        st.session_state.promptBook = prompt_book
        print('promptBook loaded')

    # Debug switch: set to True to print roster rows that share a
    # (model_id, modelVersion_id) pair but differ in another column. The roster is
    # de-duplicated over all of its columns, so such rows are genuine conflicts, and
    # they make the merge above duplicate prompt-book rows.
    check_roster_error = False
    if check_roster_error:
        roster = st.session_state.roster
        print(roster[roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)]
              .sort_values(by=['model_id', 'modelVersion_id']))

    app = GalleryApp(promptBook=st.session_state.promptBook)
    app.app()