Spaces:
Sleeping
Sleeping
File size: 7,143 Bytes
bca2bcb 8ff0942 bca2bcb 8ff0942 4933968 bca2bcb 8ff0942 00a6576 8ff0942 00a6576 8ff0942 00a6576 8ff0942 00a6576 8ff0942 00a6576 8ff0942 00a6576 8ff0942 00a6576 8ff0942 bca2bcb 3f0bdca bca2bcb 3f0bdca bca2bcb 3f0bdca bca2bcb 00a6576 4933968 00a6576 5d76d94 c6516ad 3f0bdca bca2bcb 5d76d94 00a6576 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import numpy as np
import pandas as pd
import streamlit as st
from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page
from pages.Gallery import load_hf_dataset
class RankingApp:
    """Streamlit page on which the user ranks images by dragging them in a grid.

    The page shows the images of one prompt (chosen in the sidebar) in batches
    of ``batch_size``. The current ranking is kept in
    ``st.session_state.ranking`` as ``{image_id (str): zero-based rank}`` and
    the index of the batch being shown in ``st.session_state.counter``.
    """

    # Grid geometry per layout name: (columns, x step per column, width, height).
    _LAYOUTS = {
        'portrait': (4, 1, 1, 2),
        'landscape': (2, 2, 2, 1.4),
    }

    def __init__(self, promptBook, images_endpoint, batch_size=4):
        """Store the configuration and make sure the batch counter exists.

        Args:
            promptBook: DataFrame of image records; the methods below read its
                'tag', 'prompt', 'prompt_id', 'negativePrompt' and 'image_id'
                columns.
            images_endpoint: URL prefix; an image is fetched from
                ``images_endpoint + image_id + '.png'``.
            batch_size: number of images shown at once.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.promptBook = promptBook
        self.images_endpoint = images_endpoint
        self.batch_size = batch_size

        if 'counter' not in st.session_state:
            st.session_state.counter = 0

    def _batch_count(self, n_items):
        """Return how many batches are needed for ``n_items`` (never below 1)."""
        # Ceiling division; the floor of 1 keeps the progress bar well defined
        # even when there is nothing to show.
        return max(1, -(-n_items // self.batch_size))

    @classmethod
    def _grid_cells(cls, n_items, layout='portrait'):
        """Return the ``(x, y, w, h)`` grid cell of each of ``n_items`` images.

        Raises:
            ValueError: if ``layout`` is not a known layout name.
        """
        try:
            col_num, x_step, width, height = cls._LAYOUTS[layout]
        except KeyError:
            raise ValueError(
                f"Unknown layout {layout!r}; expected one of {sorted(cls._LAYOUTS)}"
            ) from None
        return [(i % col_num * x_step, i // col_num, width, height) for i in range(n_items)]

    @staticmethod
    def _ranks_from_layout(updated_layout):
        """Map each grid item id to its rank in reading order.

        Items are ordered top to bottom, then left to right, using the 'y' and
        'x' entries of each layout item; the id is taken from its 'i' entry.
        """
        ordered = sorted(updated_layout, key=lambda cell: (cell['y'], cell['x']))
        return {str(cell['i']): rank for rank, cell in enumerate(ordered)}

    def sidebar(self):
        """Render the tag/prompt selectors and a read-only prompt form.

        Returns:
            A tuple ``(prompt_tags, tag, prompt_id, items)``: the sorted unique
            tags, the selected tag, the first prompt_id of the selected prompt,
            and the promptBook rows matching the selected tag and prompt.
        """
        with st.sidebar:
            prompt_tags = np.sort(self.promptBook['tag'].unique())
            tag = st.selectbox('Select a prompt tag', prompt_tags)

            items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)

            # Prompts are offered in descending sort order.
            prompts = np.sort(items['prompt'].unique())[::-1]
            selected_prompt = st.selectbox('Select a prompt', prompts)

            items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
            prompt_id = items['prompt_id'].unique()[0]

            with st.form(key='prompt_form'):
                # Image metadata, shown for reference only (both fields are disabled).
                st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
                st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150,
                             key='negative_prompt', disabled=True)
                st.form_submit_button('Generate Images', type='primary', use_container_width=True)

        return prompt_tags, tag, prompt_id, items

    def draggable_images(self, items, layout='portrait'):
        """Render ``items`` as a draggable grid of image cards with rank chips.

        Args:
            items: DataFrame with an 'image_id' column. Rows are read by
                position, so a slice with a non-zero-based index is fine.
            layout: 'portrait' (4 columns) or 'landscape' (2 wide columns).

        Raises:
            ValueError: if ``layout`` is not a known layout name.
        """
        image_ids = [str(image_id) for image_id in items['image_id']]
        cells = self._grid_cells(len(image_ids), layout)

        # (Re)initialise the ranking in display order whenever the stored one
        # does not describe exactly the images being shown, e.g. on first use or
        # after another prompt or batch was selected.
        if 'ranking' not in st.session_state or set(st.session_state.ranking) != set(image_ids):
            st.session_state.ranking = {image_id: pos for pos, image_id in enumerate(image_ids)}
        ranking = st.session_state.ranking

        chip_colors = {1: "primary", 2: "warning"}

        with elements('dashboard'):
            grid_layout = [
                dashboard.Item(image_id, x, y, w, h, isResizable=False)
                for image_id, (x, y, w, h) in zip(image_ids, cells)
            ]

            with dashboard.Grid(grid_layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2},
                                onLayoutChange=self.handle_layout_change, margin=[18, 18],
                                containerPadding=[0, 0]):
                for image_id in image_ids:
                    with mui.Card(key=image_id, variant="outlined"):
                        rank = ranking[image_id] + 1  # chips are one-based

                        mui.Chip(label=rank,
                                 color=chip_colors.get(rank, "info"),
                                 size="small",
                                 sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})

                        img_url = self.images_endpoint + image_id + '.png'

                        mui.CardMedia(
                            component="img",
                            image=img_url,
                            alt="There should be an image",
                            # NOTE(review): "fit" is not a valid CSS object-fit value;
                            # "contain" may have been intended - confirm before changing.
                            sx={"height": "100%", "object-fit": "fit", 'bgcolor': 'black'},
                        )

    def handle_layout_change(self, updated_layout):
        """Grid callback: recompute the ranking from the new item positions."""
        st.session_state.ranking = self._ranks_from_layout(updated_layout)

    def app(self):
        """Render the whole page: title, sidebar, current image batch, controls."""
        st.title('Personal Image Ranking')
        st.write('Here you can test out your selected images with any prompt you like.')

        *_, items = self.sidebar()

        # Batches are counted over the images of the selected prompt; the stored
        # counter is clamped so a stale value cannot push progress above 100%.
        batch_num = self._batch_count(len(items))
        batch_idx = min(st.session_state.counter, batch_num - 1)
        start = self.batch_size * batch_idx

        sorting, control = st.columns((11, 1), gap='large')
        with sorting:
            st.progress((batch_idx + 1) / batch_num, text=f"Batch {batch_idx + 1} / {batch_num}")
            self.draggable_images(items.iloc[start:start + self.batch_size], layout='portrait')

        with control:
            st.button(":arrow_right:")
            st.button(":slightly_frowning_face:")
def unique_model_versions(selected_dict):
    """Return the distinct values found in the lists of ``selected_dict``.

    Values are returned in order of first appearance; the keys are ignored.
    """
    unique = {}
    for versions in selected_dict.values():
        # dict keys de-duplicate while preserving first-seen order.
        unique.update(dict.fromkeys(versions))
    return list(unique)


if __name__ == "__main__":
    st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        home_btn = st.button('Go to Home Page')
        if home_btn:
            switch_page("home")

    else:
        # A logged-in user may arrive before any selection was stored, so fall
        # back to an empty selection instead of failing on a missing key.
        selected_dict = st.session_state.get('selected_dict', {})
        selected_modelVersions = unique_model_versions(selected_dict)

        if not selected_modelVersions:
            st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
            gallery_btn = st.button('Go to Gallery')
            if gallery_btn:
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()
            st.write("# Full function is coming soon.")
            st.write("## roster")
            st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])

            # TODO: enable the ranking UI once it is ready. Sketch: keep only the
            # promptBook rows whose prompt_id is a key of selected_dict and whose
            # modelVersion_id is among the values selected for that key.
            # (DataFrame.append was removed in pandas 2.0, hence pd.concat.)
            # promptBook_selected = pd.concat(
            #     [promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))]
            #      for key, value in selected_dict.items()]
            # ).reset_index(drop=True)
            # images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
            #
            # app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
            # app.app()
|