dvir-bria commited on
Commit
95a9f0f
1 Parent(s): f50c7eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
2
+ from diffusers.utils import load_image
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+ import gradio as gr
8
+
9
+ controlnet_conditioning_scale = 0.5 # recommended for good generalization
10
+
11
+ controlnet = ControlNetModel.from_pretrained(
12
+ "briaai/ControlNet-Canny",
13
+ torch_dtype=torch.float16
14
+ )
15
+
16
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
17
+ "briaai/BRIA-2.0",
18
+ controlnet=controlnet,
19
+ vae=vae,
20
+ torch_dtype=torch.float16,
21
+ )
22
+ pipe.enable_model_cpu_offload()
23
+
24
+ low_threshold = 100
25
+ high_threshold = 200
26
+
27
+ def get_canny_filter(image):
28
+
29
+ if not isinstance(image, np.ndarray):
30
+ image = np.array(image)
31
+
32
+ image = cv2.Canny(image, low_threshold, high_threshold)
33
+ image = image[:, :, None]
34
+ image = np.concatenate([image, image, image], axis=2)
35
+ canny_image = Image.fromarray(image)
36
+ return canny_image
37
+
38
+ def process(input_image, prompt):
39
+ canny_image = get_canny_filter(input_image)
40
+ images = pipe(
41
+ prompt,image=canny_image, controlnet_conditioning_scale=controlnet_conditioning_scale,
42
+ ).images
43
+
44
+ return [canny_image,images[0]]
45
+
46
+ block = gr.Blocks().queue()
47
+
48
+ with block:
49
+ gr.Markdown("## BRIA 2.0 ControlNet Canny")
50
+ gr.HTML('''
51
+ <p style="margin-bottom: 10px; font-size: 94%">
52
+ This is a demo for BRIA 2.0 ControlNet Canny, a fully legal and safe T2I model.
53
+ </p>
54
+ ''')
55
+ with gr.Row():
56
+ with gr.Column():
57
+ input_image = gr.Image(source='upload', type="numpy")
58
+ prompt = gr.Textbox(label="Prompt")
59
+ run_button = gr.Button(label="Run")
60
+
61
+
62
+ with gr.Column():
63
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid_cols=2, height='auto')
64
+ ips = [input_image, prompt]
65
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
66
+
67
+ block.launch(debug = True)