tori29umai committed on
Commit
2b32e3d
1 Parent(s): 881ee5b
Files changed (1)
  1. app.py +77 -117
app.py CHANGED
@@ -1,99 +1,98 @@
 import spaces
 import gradio as gr
 import torch
-
-torch.jit.script = lambda f: f
-from diffusers import (
-    ControlNetModel,
-    StableDiffusionXLControlNetImg2ImgPipeline,
-    DDIMScheduler,
-)
+from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
 from compel import Compel, ReturnedEmbeddingsType
 from PIL import Image
 import os
 import time
-import numpy as np
 
 from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
 from utils.prompt_analysis import PromptAnalysis
 
-path = os.getcwd()
-cn_dir = f"{path}/controlnet"
-os.makedirs(cn_dir)
-tagger_dir = f"{path}/tagger"
-os.mkdir(tagger_dir)
-lora_dir = f"{path}/lora"
-os.mkdir(lora_dir)
-
-load_cn_model(cn_dir)
-load_cn_config(cn_dir)
-load_tagger_model(tagger_dir)
-load_lora_model(lora_dir)
-
-IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
-IS_SPACE = os.environ.get("SPACE_ID", None) is not None
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-dtype = torch.float16
-
-LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
-
-print(f"device: {device}")
-print(f"dtype: {dtype}")
-print(f"low memory: {LOW_MEMORY}")
-
-
-model = "cagliostrolab/animagine-xl-3.1"
-scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
-controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=torch.float16, use_safetensors=True)
-pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-    model,
-    controlnet=controlnet,
-    torch_dtype=dtype,
-    use_safetensors=True,
-    scheduler=scheduler,
-)
-
-pipe.load_lora_weights(
-    lora_dir,
-    weight_name="sdxl_BWLine.safetensors"
-)
-
-
-compel = Compel(
-    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
-    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
-    requires_pooled=[False, True],
-)
-pipe = pipe.to(device)
-
-
-
 class Img2Img:
     def __init__(self):
-        self.input_image_path = None
+        self.setup_paths()
+        self.setup_models()
+        self.compel = self.setup_compel()
+        self.demo = self.layout()
+
+    def setup_paths(self):
+        self.path = os.getcwd()
+        self.cn_dir = f"{self.path}/controlnet"
+        self.tagger_dir = f"{self.path}/tagger"
+        self.lora_dir = f"{self.path}/lora"
+        os.makedirs(self.cn_dir, exist_ok=True)
+        os.makedirs(self.tagger_dir, exist_ok=True)
+        os.makedirs(self.lora_dir, exist_ok=True)
+
+    def setup_models(self):
+        load_cn_model(self.cn_dir)
+        load_cn_config(self.cn_dir)
+        load_tagger_model(self.tagger_dir)
+        load_lora_model(self.lora_dir)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.dtype = torch.float16
+        self.model = "cagliostrolab/animagine-xl-3.1"
+        self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler")
+        self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True)
+        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+            self.model,
+            controlnet=self.controlnet,
+            torch_dtype=self.dtype,
+            use_safetensors=True,
+            scheduler=self.scheduler,
+        )
+        self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
+        self.pipe = self.pipe.to(self.device)
+
+    def setup_compel(self):
+        return Compel(
+            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+            requires_pooled=[False, True],
+        )
+
+    def layout(self):
+        css = """
+        #intro{
+            max-width: 32rem;
+            text-align: center;
+            margin: 0 auto;
+        }
+        """
+        with gr.Blocks(css=css) as demo:
+            with gr.Row():
+                with gr.Column():
+                    self.input_image_path = gr.Image(label="入力画像", type='filepath')
+                    self.prompt_analysis = PromptAnalysis(self.tagger_dir)
+                    self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
+                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
+                    generate_button = gr.Button("生成")
+                with gr.Column():
+                    self.output_image = gr.Image(type="pil", label="生成画像")
+
+            generate_button.click(
+                fn=self.predict,
+                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
+                outputs=self.output_image
+            )
+        return demo
 
     @spaces.GPU
-    def predict(
-        self,
-        input_image_path,
-        prompt,
-        negative_prompt,
-        controlnet_conditioning_scale,
-    ):
+    def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
         input_image_pil = Image.open(input_image_path)
-        base_size =input_image_pil.size
-        resize_image= resize_image_aspect_ratio(input_image_pil)
+        base_size = input_image_pil.size
+        resize_image = resize_image_aspect_ratio(input_image_pil)
         resize_image_size = resize_image.size
-        width = resize_image_size[0]
-        height = resize_image_size[1]
+        width, height = resize_image_size
         white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
-        conditioning, pooled = compel([prompt, negative_prompt])
+        conditioning, pooled = self.compel([prompt, negative_prompt])
         generator = torch.manual_seed(0)
         last_time = time.time()
 
-        output_image = pipe(
+        output_image = self.pipe(
             image=white_base_pil,
             control_image=resize_image,
             strength=1.0,
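
Note on the Compel pattern above: `self.compel([prompt, negative_prompt])` encodes the positive and negative prompts as one batch, and `predict` then slices row 0 as the positive embeddings and row 1 as the negative ones. Only the `pooled[1:2]` slice is visible in the next hunk; the positive-side slices below are an assumption mirroring it, since those lines fall outside the diff context. A minimal runnable sketch with stand-in tensors (shapes are illustrative, not taken from the commit):

import torch

# Stand-ins for the (embeddings, pooled) pair returned by
# compel([prompt, negative_prompt]); row 0 = positive, row 1 = negative.
conditioning = torch.randn(2, 77, 2048)  # illustrative per-token embedding shape
pooled = torch.randn(2, 1280)            # illustrative pooled embedding shape

pipe_kwargs = dict(
    prompt_embeds=conditioning[0:1],            # assumed positive slice
    pooled_prompt_embeds=pooled[0:1],           # assumed positive slice
    negative_prompt_embeds=conditioning[1:2],
    negative_pooled_prompt_embeds=pooled[1:2],  # as shown in the diff
)
print({k: tuple(v.shape) for k, v in pipe_kwargs.items()})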
@@ -103,7 +102,7 @@ class Img2Img:
             negative_pooled_prompt_embeds=pooled[1:2],
             width=width,
             height=height,
-            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+            controlnet_conditioning_scale=float(controlnet_scale),
             controlnet_start=0.0,
             controlnet_end=1.0,
             generator=generator,
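
For context, `controlnet_conditioning_scale` weights how strongly the ControlNet input (here, the line art) steers generation; the layout above exposes it as the 線画忠実度 ("line-art fidelity") slider over 0.5–1.25. A small sketch of normalizing a UI value the way `predict` does; the clamping is an illustrative addition, not something the commit performs:

SCALE_MIN, SCALE_MAX = 0.5, 1.25  # the slider bounds defined in layout()

def normalize_scale(value) -> float:
    # predict() applies only float(); clamping added here for illustration
    return min(max(float(value), SCALE_MIN), SCALE_MAX)

print(normalize_scale("1.0"), normalize_scale(2))  # -> 1.0 1.25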
@@ -115,44 +114,5 @@ class Img2Img:
         output_image = output_image.resize(base_size, Image.LANCZOS)
         return output_image
 
-
-    css = """
-    #intro{
-    # max-width: 32rem;
-    # text-align: center;
-    # margin: 0 auto;
-    }
-    """
-    def layout(self,css):
-        with gr.Blocks(css=css) as demo:
-            with gr.Column():
-                # Row for the image upload
-                with gr.Row():
-                    with gr.Column():
-                        self.input_image_path = gr.Image(label="入力画像", type='filepath')
-
-                # Row for the prompt input
-                with gr.Row():
-                    prompt_analysis = PromptAnalysis(tagger_dir)
-                    [prompt, nega] = prompt_analysis.layout(self.input_image_path)
-                # Slider row for detailed image settings
-                with gr.Row():
-                    controlnet_conditioning_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, interactive=True, label="線画忠実度")
-
-                # Row for the generate button
-                with gr.Row():
-                    generate_button = gr.Button("生成", interactive=False)
-
-                with gr.Column():
-                    output_image = gr.Image(type="pil", label="Output Image")
-
-            # Wire up the button's click event
-            generate_button.click(
-                fn=self.predict,
-                inputs=[self.input_image_path, prompt, nega, controlnet_conditioning_scale],
-                outputs=[output_image]
-            )
-
-        # Configure and launch the demo
-        demo.queue(api_open=True)
-        demo.launch(show_api=True)
+img2img = Img2Img()
+img2img.demo.launch()
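
After this commit the module has a single entry point: constructing `Img2Img` builds the directories, models, Compel wrapper, and Gradio UI, and `demo.launch()` starts the Space. A minimal runnable sketch of the same pattern, with a stub `predict` standing in for the diffusion pipeline:

import gradio as gr
from PIL import Image

class AppSketch:
    # Same shape as the refactored Img2Img: build everything in __init__
    # and expose the Blocks object as .demo.
    def __init__(self):
        self.demo = self.layout()

    def layout(self):
        with gr.Blocks() as demo:
            inp = gr.Image(label="input", type="filepath")
            btn = gr.Button("generate")
            out = gr.Image(type="pil", label="output")
            btn.click(fn=self.predict, inputs=[inp], outputs=out)
        return demo

    def predict(self, input_image_path):
        # Stub: echo the input; the real app runs the SDXL ControlNet pipeline.
        return Image.open(input_image_path)

if __name__ == "__main__":
    AppSketch().demo.launch()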