from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image
import torch
import numpy as np
import cv2
import gradio as gr
from torchvision import transforms

# Load the Canny-edge ControlNet in half precision to reduce GPU memory use.
controlnet = ControlNetModel.from_pretrained(
    "briaai/ControlNet-Canny",
    torch_dtype=torch.float16
)

# Assemble the SDXL ControlNet pipeline around BRIA 2.0 and move it to the GPU.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    # force_zeros_for_empty_prompt=False,  # uncomment to encode the empty prompt instead of zeroing it
).to("cuda")
# Memory-efficient attention requires the xformers package and a CUDA device.
pipe.enable_xformers_memory_efficient_attention()

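# Canny hysteresis thresholds: gradients above high_threshold become edges;
# gradients between the two are kept only if connected to a strong edge.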
low_threshold = 100
high_threshold = 200

def resize_image(image):
    """Center-crop to a square, then resize to 1024x1024, the resolution SDXL expects."""
    image = image.convert('RGB')
    width, height = image.size
    side = min(width, height)
    center_cropped_image = transforms.functional.center_crop(image, (side, side))
    resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
    return resized_image

def get_canny_filter(image):
    """Run Canny edge detection and stack the result to 3 channels for ControlNet."""
    if not isinstance(image, np.ndarray):
        image = np.array(image) 
        
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image

def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    # Seed the global torch generator so results are reproducible.
    generator = torch.manual_seed(seed)

    # Crop/resize the input to 1024x1024, then extract its edge map for conditioning.
    input_image = resize_image(input_image)
    canny_image = get_canny_filter(input_image)

    # FreeU rescales the UNet's backbone (b1, b2) and skip-connection (s1, s2)
    # features at inference time to improve sample quality at no training cost.
    pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)

    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=canny_image,
        num_inference_steps=num_steps,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images

    return [canny_image, images[0]]
    
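# queue() routes incoming requests through Gradio's queue so they are processed in order.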
block = gr.Blocks().queue()

with block:
    gr.Markdown("## BRIA 2.0 ControlNet Canny")
    gr.HTML('''
      <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo of ControlNet Canny using the
        <a href="https://huggingface.co/briaai/BRIA-2.0" target="_blank">BRIA 2.0 text-to-image model</a> as its backbone.
        Trained on licensed data, BRIA 2.0 provides full legal liability coverage for copyright and privacy infringement.
      </p>
    ''')
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=None, type="pil")  # None enables upload, clipboard paste (Ctrl+V), and webcam
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
            run_button = gr.Button(value="Run")
            
            
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

block.launch(debug=True)