lx0xl commited on
Commit
1daa995
1 Parent(s): db36c8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import requests
4
+ import time
5
+ import json
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ from io import BytesIO
10
+
11
+ class Prodia:
12
+ def __init__(self, api_key, base=None):
13
+ self.base = base or "https://api.prodia.com/v1"
14
+ self.headers = {
15
+ "X-Prodia-Key": api_key
16
+ }
17
+
18
+ def generate(self, params):
19
+ response = self._post(f"{self.base}/sdxl/generate", params)
20
+ return response.json()
21
+
22
+ def get_job(self, job_id):
23
+ response = self._get(f"{self.base}/job/{job_id}")
24
+ return response.json()
25
+
26
+ def wait(self, job):
27
+ job_result = job
28
+
29
+ while job_result['status'] not in ['succeeded', 'failed']:
30
+ time.sleep(0.25)
31
+ job_result = self.get_job(job['job'])
32
+
33
+ return job_result
34
+
35
+ def list_models(self):
36
+ response = self._get(f"{self.base}/sdxl/models")
37
+ return response.json()
38
+
39
+ def list_samplers(self):
40
+ response = self._get(f"{self.base}/sdxl/samplers")
41
+ return response.json()
42
+
43
+
44
+ def generate_v2(self, config):
45
+ response = self._post("https://inference.prodia.com/v2/job", {"type": "v2.job.sdxl.txt2img", "config": config}, v2=True)
46
+ return Image.open(BytesIO(response.content)).convert("RGBA")
47
+
48
+
49
+ def _post(self, url, params, v2=False):
50
+ headers = {
51
+ **self.headers,
52
+ "Content-Type": "application/json"
53
+ }
54
+ if v2:
55
+ headers['Authorization'] = f"Bearer {os.getenv('API_KEY')}"
56
+
57
+ response = requests.post(url, headers=headers, data=json.dumps(params))
58
+
59
+ if response.status_code != 200:
60
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
61
+
62
+ return response
63
+
64
+ def _get(self, url):
65
+ response = requests.get(url, headers=self.headers)
66
+
67
+ if response.status_code != 200:
68
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
69
+
70
+ return response
71
+
72
+
73
+
74
+
75
+ def image_to_base64(image_path):
76
+ # Open the image with PIL
77
+ with Image.open(image_path) as image:
78
+ # Convert the image to bytes
79
+ buffered = BytesIO()
80
+ image.save(buffered, format="PNG") # You can change format to PNG if needed
81
+
82
+ # Encode the bytes to base64
83
+ img_str = base64.b64encode(buffered.getvalue())
84
+
85
+ return img_str.decode('utf-8') # Convert bytes to string
86
+
87
+
88
+
89
+ prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
90
+
91
+ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, resolution, seed):
92
+
93
+ width, height = resolution.split("x")
94
+
95
+ config_without_model_and_sampler = {
96
+ "prompt": prompt,
97
+ "negative_prompt": negative_prompt,
98
+ "steps": steps,
99
+ "cfg_scale": cfg_scale,
100
+ "width": int(width),
101
+ "height": int(height),
102
+ "seed": seed
103
+ }
104
+
105
+ # 本条注释替换成下面那条,避免sd_xl_base_1.0.safetensors [be9edd61]使用API_KEY导致的无法使用
106
+ # if model == "sd_xl_base_1.0.safetensors [be9edd61]":
107
+
108
+ if model == "xxxxx":
109
+ return prodia_client.generate_v2(config_without_model_and_sampler)
110
+
111
+ result = prodia_client.generate({
112
+ **config_without_model_and_sampler,
113
+ "model": model,
114
+ "sampler": sampler
115
+ })
116
+
117
+ job = prodia_client.wait(result)
118
+
119
+ return job["imageUrl"]
120
+
121
+ css = """
122
+ #generate {
123
+ height: 100%;
124
+ }
125
+ """
126
+
127
+ list_resolutions = [
128
+ "512x512",
129
+ "640x960",
130
+ "800x1200",
131
+ "1280x720",
132
+ "1368x768",
133
+ "1024x1024",
134
+ "1216x832",
135
+ "1344x768",
136
+ "1536x640",
137
+ "640x1536",
138
+ "768x1344",
139
+ "832x1216"
140
+ ]
141
+
142
+ with gr.Blocks(css=css) as demo:
143
+
144
+
145
+ with gr.Row():
146
+ with gr.Column(scale=6):
147
+ model = gr.Dropdown(interactive=True,value="animagineXLV3_v30.safetensors [75f2f05b]", show_label=True, label="Stable Diffusion Checkpoint", choices=prodia_client.list_models())
148
+
149
+ with gr.Column(scale=1):
150
+ gr.Markdown(elem_id="powered-by-prodia", value="AUTOMATIC1111 Stable Diffusion Web UI for SDXL V1.0.<br>Powered by [Prodia](https://prodia.com).")
151
+
152
+ with gr.Tab("txt2img"):
153
+ with gr.Row():
154
+ with gr.Column(scale=6, min_width=600):
155
+ prompt = gr.Textbox("(masterpiece,highres,best quality,8k),1girl,solo,space warrior, ultrarealistic, soft lighting", placeholder="Prompt", show_label=False, lines=3)
156
+ negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3, value="(nsfw:1.2),lowres,[bad anatomy,bad hands,missing fingers,long neck],text,error")
157
+ with gr.Column():
158
+ text_button = gr.Button("Generate", variant='primary', elem_id="generate")
159
+
160
+ with gr.Row():
161
+ with gr.Column(scale=3):
162
+ with gr.Tab("Generation"):
163
+ with gr.Row():
164
+ with gr.Column(scale=1):
165
+ sampler = gr.Dropdown(value="DPM++ 2M Karras", show_label=True, label="Sampling Method", choices=prodia_client.list_samplers())
166
+
167
+ with gr.Column(scale=1):
168
+ steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, value=25, step=1)
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ resolution = gr.Dropdown(value="800x1200", show_label=True, label="Resolution", choices=list_resolutions)
173
+
174
+ with gr.Column(scale=1):
175
+ batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
176
+ batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
177
+
178
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=10, step=1)
179
+ seed = gr.Number(label="Seed", value=-1)
180
+
181
+ with gr.Column(scale=2):
182
+ # image_output = gr.Image(value="https://cdn-uploads.huggingface.co/production/uploads/noauth/XWJyh9DhMGXrzyRJk7SfP.png")
183
+ image_output = gr.Image(value="./image.png")
184
+
185
+ text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, resolution, seed], outputs=image_output)
186
+
187
+ demo.queue(default_concurrency_limit=1, max_size=32, api_open=False).launch(max_threads=128, auth=(os.getenv("USERNAME"), os.getenv("PASSWORD")))