sketch2lineart / app.py
tori29umai's picture
app.py
61542fc
raw
history blame
No virus
4.93 kB
import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import torch
# Replace TorchScript compilation with an identity function: anything decorated
# with @torch.jit.script after this point runs as plain eager Python.
# NOTE(review): presumably a workaround for one of the libraries imported below
# that applies @torch.jit.script at import time — confirm which one needs it.
torch.jit.script = lambda f: f
from diffusers import (
ControlNetModel,
StableDiffusionXLControlNetImg2ImgPipeline,
DDIMScheduler,
)
from controlnet_aux import AnylineDetector
from compel import Compel, ReturnedEmbeddingsType
from PIL import Image
import os
import time
import numpy as np
from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
from utils.prompt_analysis import PromptAnalysis
# Working directories for the downloaded assets, all rooted at the current
# working directory: ControlNet weights + config, the tagger model, and LoRAs.
path = os.getcwd()
cn_dir = f"{path}/controlnet"
tagger_dir = f"{path}/tagger"
lora_dir = f"{path}/lora"

# exist_ok=True: a restart (or a second run in the same checkout) must not
# crash with FileExistsError because the directories already exist.
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(tagger_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)

# Populate the directories using the helpers from utils.utils.
load_cn_model(cn_dir)
load_cn_config(cn_dir)
load_tagger_model(tagger_dir)
load_lora_model(lora_dir)
# --- Runtime environment flags (all derived from environment variables) ---
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = "SPACE_ID" in os.environ
LOW_MEMORY = os.environ.get("LOW_MEMORY", "0") == "1"

# --- Inference placement: GPU when one is visible, otherwise CPU ---
# NOTE(review): dtype is half precision even on the CPU fallback; some ops may
# not support float16 on CPU — confirm the CPU path is actually used/needed.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

# Echo the resolved settings so they appear in the startup logs.
print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")
# Base SDXL checkpoint (Hugging Face Hub id); its scheduler config is reused.
model = "cagliostrolab/animagine-xl-3.1"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")

# ControlNet weights were downloaded into cn_dir. Use the shared `dtype` so the
# ControlNet and the base pipeline are always loaded at the same precision.
controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    use_safetensors=True,
    scheduler=scheduler,
)

# LoRA from lora_dir (presumably the black-and-white line-art style — confirm).
pipe.load_lora_weights(
    lora_dir,
    weight_name="sdxl_BWLine.safetensors",
)

# Compel builds prompt embeddings from both SDXL tokenizers/text encoders;
# only the second encoder supplies pooled embeddings (requires_pooled).
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

pipe = pipe.to(device)
@spaces.GPU
def predict(
    input_image_path,
    prompt,
    negative_prompt,
    controlnet_conditioning_scale,
):
    """Generate line art from an uploaded sketch with the SDXL ControlNet pipeline.

    Args:
        input_image_path: Filesystem path of the uploaded image (the value of a
            ``gr.Image(type='filepath')`` component).
        prompt: Positive prompt text, encoded with Compel.
        negative_prompt: Negative prompt text, encoded with Compel.
        controlnet_conditioning_scale: ControlNet strength from the UI slider;
            coerced to ``float``.

    Returns:
        The generated ``PIL.Image.Image``, resized back to the input image's
        original size.
    """
    # Copy the pixels out so the file handle is released right away instead of
    # whenever the lazily-opened image happens to be garbage collected.
    with Image.open(input_image_path) as src:
        input_image_pil = src.copy()
    base_size = input_image_pil.size

    # Working-resolution control image (helper from utils.utils) plus a plain
    # white RGB canvas of the same size as the img2img starting image.
    resize_image = resize_image_aspect_ratio(input_image_pil)
    width, height = resize_image.size
    white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")

    # Encode both prompts in one batch: row 0 = prompt, row 1 = negative prompt.
    conditioning, pooled = compel([prompt, negative_prompt])

    # Fixed seed so identical inputs reproduce the same output.
    generator = torch.manual_seed(0)

    start_time = time.time()
    result = pipe(
        image=white_base_pil,
        control_image=resize_image,
        # strength=1.0: per the diffusers docs the starting image is fully
        # noised, so the result is driven by the prompt and ControlNet image.
        strength=1.0,
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=width,
        height=height,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        # Correct parameter names; the previous controlnet_start/controlnet_end
        # were not pipeline arguments and were silently ignored.
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=8.5,
        eta=1.0,
    )
    print(f"Time taken: {time.time() - start_time}")

    # The pipeline returns an output object, not an image; the generated
    # PIL image is the first entry of `.images`.
    output_image = result.images[0]
    return output_image.resize(base_size, Image.LANCZOS)
# Custom stylesheet passed to gr.Blocks. The #intro declarations are
# intentionally disabled. They are wrapped in real CSS comments because `#` is
# not a comment marker in CSS; the old `# ...` lines were only dropped by the
# browser as invalid declarations.
css = """
#intro{
    /* max-width: 32rem; */
    /* text-align: center; */
    /* margin: 0 auto; */
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Row() as block:
        # Left column: inputs and controls.
        with gr.Column():
            # Row for the image upload
            with gr.Row():
                with gr.Column():
                    input_image_path = gr.Image(label="入力画像", type='filepath')
            # Row for prompt input (textboxes built by PromptAnalysis)
            with gr.Row():
                prompt_analysis = PromptAnalysis(tagger_dir)
                [prompt, nega] = prompt_analysis.layout(input_image_path)
            # Slider row for detailed image settings
            with gr.Row():
                controlnet_conditioning_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, interactive=True, label="線画忠実度")
            # Row for the generate button
            with gr.Row():
                generate_button = gr.Button("生成", interactive=False)
        # Right column: generated result.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")

    # Input/output wiring, in the same order as predict()'s parameters.
    inputs = [
        input_image_path,
        prompt,
        nega,
        controlnet_conditioning_scale,
    ]
    outputs = [output_image]

    # The button starts disabled, and nothing else in this file re-enables it.
    # Enable it whenever an image is present, since predict() needs one.
    input_image_path.change(
        fn=lambda image_path: gr.update(interactive=image_path is not None),
        inputs=input_image_path,
        outputs=generate_button,
    )

    # Set the button click event, reusing the lists above.
    generate_button.click(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
    )

# Demo configuration and launch
demo.queue(api_open=True)
demo.launch(show_api=True)