cosmicman committed
Commit dc988ac
1 Parent(s): fa4b160

Upload demo_sdxl.py

Files changed (1)
  1. demo_sdxl.py +212 -0
demo_sdxl.py ADDED
@@ -0,0 +1,212 @@
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel
from diffusers.utils import load_image
from diffusers import (
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
import torch
import os
import random
import numpy as np
from PIL import Image
from typing import Tuple
import gradio as gr

DESCRIPTION = """
# CosmicMan
- CosmicMan: A Text-to-Image Foundation Model for Humans (CVPR 2024 (Highlight))
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
schedule_map = {
    "ddim": DDIMScheduler,
    "pndm": PNDMScheduler,
    "lms": LMSDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "euler_a": EulerAncestralDiscreteScheduler,
    "dpm": DPMSolverMultistepScheduler,
}

examples = [
    "A fit Caucasian elderly woman, her wavy white hair above shoulders, wears a pink floral cotton long-sleeve shirt and a cotton hat against a natural landscape in an upper body shot",
    "A closeup of a doll with a purple ribbon around her neck, best quality, extremely detailed",
    "A closeup of a girl with a butterfly painted on her face",
    "A headshot, an asian elderly male, a blue wall, bald above eyes gray hair",
    "A closeup portrait shot against a white wall, a fit Caucasian adult female with wavy blonde hair falling above her chest wears a short sleeve silk floral dress and a floral silk normal short sleeve white blouse",
    "A headshot, an adult caucasian male, fit, a white wall, red crew cut curly hair, short sleeve normal blue t-shirt, best quality, extremely detailed",
    "A closeup of a man wearing a red shirt with a flower design on it",
    "There is a man wearing a mask and holding a cell phone",
    "Two boys playing in the yard",
]
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_SEED = np.iinfo(np.int32).max
NUM_IMAGES_PER_PROMPT = 1
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    # Merge the selected style template with the user's prompt and negative prompt.
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    if not negative:
        negative = ""
    return p.replace("{prompt}", positive), n + negative

class NoWatermark:
    # Stub that disables the invisible watermark normally applied by the SDXL pipelines.
    def apply_watermark(self, img):
        return img

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
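# Illustrative example of the style templating above: with style "Cinematic",
# prompt "a portrait of an astronaut" and negative prompt "blurry", apply_style returns
#   ("cinematic still a portrait of an astronaut . emotional, harmonious, ...",
#    "anime, cartoon, graphic, ...disfigured" + "blurry"),
# i.e. the user's negative prompt is appended directly to the style's negative prompt.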
print("Loading Model!")
schedule: str = "euler_a"
base_model_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_path: str = "stabilityai/stable-diffusion-xl-refiner-1.0"
unet_path: str = "cosmicman/CosmicMan-SDXL"
SCHEDULER = schedule_map[schedule]
scheduler = SCHEDULER.from_pretrained(base_model_path, subfolder="scheduler", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch.float16)

pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    unet=unet,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
pipe.watermark = NoWatermark()
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    base_model_path,  # we found that using base_model_path instead of refiner_model_path gives better results
    scheduler=scheduler,
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
refiner.watermark = NoWatermark()
print("Finished loading model!")
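# Note: only the UNet is swapped in from cosmicman/CosmicMan-SDXL; the text encoders,
# VAE, and scheduler come from the SDXL base checkpoint, and the img2img refinement pass
# below reuses the base checkpoint rather than the separate refiner weights.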
def generate_image(prompt,
                   n_prompt="",
                   style: str = DEFAULT_STYLE_NAME,
                   steps: int = 50,
                   height: int = 1024,
                   width: int = 1024,
                   scale: float = 7.5,
                   img_num: int = 4,
                   seeds: int = 42,
                   random_seed: bool = False,
                   ):
    print("Begin to generate")
    image_list = []
    for i in range(img_num):
        seed = int(randomize_seed_fn(seeds, random_seed))
        generator = torch.Generator(device="cuda").manual_seed(seed)
        positive_prompt, negative_prompt = apply_style(style, prompt, n_prompt)
        # Base pass: generate latents with the CosmicMan UNet, then refine them.
        image = pipe(positive_prompt, num_inference_steps=steps,
                     guidance_scale=scale, height=height,
                     width=width, negative_prompt=negative_prompt,
                     generator=generator, output_type="latent").images[0]
        image = refiner(positive_prompt, negative_prompt=negative_prompt, image=image[None, :]).images[0]
        image_list.append((image, f"Seed {seed}"))
    return image_list
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            with gr.Column():
                input_prompt = gr.Textbox(label="Input prompt", lines=3, max_lines=5)
                negative_prompt = gr.Textbox(label="Negative prompt", value="")
                run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", show_label=False, elem_id="gallery", columns=[2], rows=[2], object_fit="contain", height="auto")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            style_selection = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                label="Image Style",
            )
        with gr.Row():
            height = gr.Slider(minimum=512, maximum=1536, value=1024, label="Height", step=64)
            width = gr.Slider(minimum=512, maximum=1536, value=1024, label="Width", step=64)
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=50, value=30, label="Number of diffusion steps", step=1)
            scale = gr.Number(minimum=1, maximum=12, value=7.5, label="Guidance scale")
        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            random_seed = gr.Checkbox(label="Randomize seed", value=True)
            img_num = gr.Slider(minimum=1, maximum=4, value=4, label="Number of images", step=1)

    gr.Examples(
        examples=examples,
        inputs=input_prompt,
        outputs=result,
        fn=generate_image,
        cache_examples=CACHE_EXAMPLES,
    )

    gr.on(
        triggers=[
            input_prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate_image,
        inputs=[input_prompt, negative_prompt, style_selection, steps, height, width, scale, img_num, seed, random_seed],
        outputs=result,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True, server_name='0.0.0.0', server_port=10057)
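A minimal usage sketch (not part of this commit): assuming the file above is saved as demo_sdxl.py and a CUDA GPU is available, generate_image can also be called directly, without launching the Gradio UI.

    from demo_sdxl import generate_image  # importing loads both pipelines onto the GPU

    results = generate_image(
        "A closeup of a girl with a butterfly painted on her face",
        style="Photographic",
        steps=30,
        img_num=1,
        seeds=1234,
    )
    for img, caption in results:  # each entry is a (PIL.Image, "Seed <n>") tuple
        img.save(f"{caption.replace(' ', '_')}.png")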