File size: 7,143 Bytes
bca2bcb
 
8ff0942
 
 
bca2bcb
8ff0942
4933968
bca2bcb
8ff0942
00a6576
 
8ff0942
00a6576
 
 
 
8ff0942
00a6576
 
8ff0942
 
 
 
00a6576
8ff0942
00a6576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ff0942
 
 
 
00a6576
8ff0942
00a6576
 
 
 
 
 
 
 
 
 
 
 
8ff0942
 
bca2bcb
 
 
 
 
 
 
 
 
 
3f0bdca
bca2bcb
 
3f0bdca
 
bca2bcb
3f0bdca
bca2bcb
 
 
 
 
00a6576
4933968
00a6576
5d76d94
 
 
c6516ad
 
3f0bdca
 
bca2bcb
5d76d94
 
 
 
 
 
 
 
 
00a6576
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import numpy as np
import pandas as pd
import streamlit as st

from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page

from pages.Gallery import load_hf_dataset


class RankingApp:
    """Streamlit page that lets a user rank images by dragging them on a grid.

    Images of the selected prompt are shown ``batch_size`` at a time. The
    current order of the displayed batch is kept in
    ``st.session_state.ranking`` as a mapping ``image_id (str) -> 0-based rank``.
    """

    def __init__(self, promptBook, images_endpoint, batch_size=4):
        """Store the data source and paging settings.

        Args:
            promptBook: DataFrame of prompts/images. This class reads the
                columns ``tag``, ``prompt``, ``prompt_id``, ``negativePrompt``
                and ``image_id``.
            images_endpoint: URL prefix; each image is loaded from
                ``images_endpoint + str(image_id) + '.png'``.
            batch_size: number of images shown per batch (must be >= 1).

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        self.promptBook = promptBook
        self.images_endpoint = images_endpoint
        self.batch_size = batch_size
        # app() divides by batch_num, so it must always exist and be >= 1.
        # (It was previously commented out, making app() raise AttributeError.)
        self.batch_num = self._batch_count(len(self.promptBook), self.batch_size)

        if 'counter' not in st.session_state:
            st.session_state.counter = 0

    @staticmethod
    def _batch_count(total, batch_size):
        """Return how many batches of ``batch_size`` cover ``total`` items, at least 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Ceiling division; max(1, ...) keeps the progress-bar denominator non-zero.
        return max(1, -(-total // batch_size))

    @staticmethod
    def _ranks_from_layout(updated_layout):
        """Map each grid item id (as str) to its 0-based rank in row-then-column order."""
        ordered = sorted(updated_layout, key=lambda item: (item['y'], item['x']))
        return {str(item['i']): rank for rank, item in enumerate(ordered)}

    def sidebar(self):
        """Render the prompt-selection sidebar.

        Returns:
            tuple: ``(prompt_tags, tag, prompt_id, items)`` - the sorted unique
            tags, the selected tag, the id of the selected prompt, and the rows
            of ``promptBook`` matching the selected tag and prompt (index reset).
        """
        with st.sidebar:
            prompt_tags = np.sort(self.promptBook['tag'].unique())

            tag = st.selectbox('Select a prompt tag', prompt_tags)
            items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
            # Prompts are offered in reverse-sorted order.
            prompts = np.sort(items['prompt'].unique())[::-1]

            selected_prompt = st.selectbox('Select a prompt', prompts)

            items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
            prompt_id = items['prompt_id'].unique()[0]

            with st.form(key='prompt_form'):
                # Read-only display of the selected prompt's metadata.
                st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
                st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
                st.form_submit_button('Generate Images', type='primary', use_container_width=True)

        return prompt_tags, tag, prompt_id, items

    def draggable_images(self, items, layout='portrait'):
        """Render ``items`` as a draggable grid of image cards with rank chips.

        Args:
            items: DataFrame (or slice of one) with an ``image_id`` column.
                Its index is ignored, so ``iloc`` slices with non-zero-based
                labels work.
            layout: ``'portrait'`` (4 columns of 1x2 cells) or
                ``'landscape'`` (2 columns of 2x1.4 cells).

        Raises:
            ValueError: for any other ``layout`` value.
        """
        if layout not in ('portrait', 'landscape'):
            raise ValueError(f"layout must be 'portrait' or 'landscape', got {layout!r}")

        # Positional ids: label-based items['image_id'][i] raised KeyError for
        # every batch after the first, because iloc slices keep their labels.
        image_ids = [str(image_id) for image_id in items['image_id']]

        # (Re)initialise the ranking whenever the displayed image set changes
        # (new prompt or batch). Otherwise stale keys break the lookups below
        # and in handle_layout_change.
        ranking = st.session_state.get('ranking')
        if ranking is None or set(ranking) != set(image_ids):
            st.session_state.ranking = {image_id: i for i, image_id in enumerate(image_ids)}

        with elements('dashboard'):
            if layout == 'portrait':
                col_num = 4
                grid_layout = [
                    dashboard.Item(image_id, i % col_num, i // col_num, 1, 2, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]
            else:  # 'landscape'
                col_num = 2
                grid_layout = [
                    dashboard.Item(image_id, i % col_num * 2, i // col_num, 2, 1.4, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]

            with dashboard.Grid(grid_layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
                for image_id in image_ids:
                    with mui.Card(key=image_id, variant="outlined"):
                        rank = st.session_state.ranking[image_id] + 1

                        mui.Chip(label=rank,
                                 color="primary" if rank == 1 else "warning" if rank == 2 else "info",
                                 size="small",
                                 sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})

                        img_url = self.images_endpoint + image_id + '.png'

                        # NOTE(review): "fit" is not a valid CSS object-fit value
                        # (contain/cover/fill/none/scale-down); confirm intent.
                        mui.CardMedia(
                            component="img",
                            image=img_url,
                            alt="There should be an image",
                            sx={"height": "100%", "object-fit": "fit", 'bgcolor': 'black'},
                        )

    def handle_layout_change(self, updated_layout):
        """Grid ``onLayoutChange`` callback: store each image's new 0-based rank.

        Args:
            updated_layout: list of grid items (dicts with ``i``, ``x``, ``y``)
                as passed by ``dashboard.Grid``.
        """
        ranks = self._ranks_from_layout(updated_layout)
        ranking = st.session_state.ranking
        for image_id in ranking:
            # Ids missing from this layout keep their rank instead of raising ValueError.
            if image_id in ranks:
                ranking[image_id] = ranks[image_id]

    def app(self):
        """Render the page: title, sidebar, the current batch's grid and controls."""
        st.title('Personal Image Ranking')
        st.write('Here you can test out your selected images with any prompt you like.')

        prompt_tags, tag, prompt_id, items = self.sidebar()

        # Page through the selected prompt's images, batch_size at a time.
        self.batch_num = self._batch_count(len(items), self.batch_size)
        counter = st.session_state.counter
        start = self.batch_size * counter
        batch = items.iloc[start:start + self.batch_size]

        sorting, control = st.columns((11, 1), gap='large')
        with sorting:
            # st.progress only accepts values in [0, 1].
            st.progress(min(1.0, (counter + 1) / self.batch_num), text=f"Batch {counter + 1} / {self.batch_num}")
            self.draggable_images(batch, layout='portrait')

        with control:
            # TODO: these buttons are not wired to any action yet.
            st.button(":arrow_right:")
            st.button(":slightly_frowning_face:")


if __name__ == "__main__":
    st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        if st.button('Go to Home Page'):
            switch_page("home")

    else:
        # Unique model-version ids across all selected prompts, in first-seen order.
        # NOTE(review): selected_dict may be unset if the gallery page was never
        # visited; fall back to an empty selection instead of raising AttributeError.
        selected_dict = st.session_state.get('selected_dict', {})
        selected_modelVersions = list(dict.fromkeys(
            v for value in selected_dict.values() for v in value
        ))

        if not selected_modelVersions:
            st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
            if st.button('Go to Gallery'):
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()
            st.write("# Full function is coming soon.")
            st.write("## roster")
            st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])

            # Planned ranking flow (disabled until ready).
            # Select promptBook rows whose prompt_id is a selected_dict key and whose
            # modelVersion_id is among that key's values. (An earlier note said the
            # keys were tags; the code compares against prompt_id - confirm.)
            # DataFrame.append was removed in pandas 2.0, so build with pd.concat:
            # promptBook_selected = pd.concat(
            #     [
            #         promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))]
            #         for key, value in selected_dict.items()
            #     ],
            #     ignore_index=True,
            # )
            # images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
            #
            # app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
            # app.app()