# LoCo / app.py
# Hugging Face Space page residue: author Pusheen, commit "Update app.py"
# (127d448, verified), 22.4 kB, virus scan: no virus.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="4"
import json
import math
import sys
import warnings
from functools import partial
from typing import Optional

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler
from gradio import processing_utils
from PIL import Image, ImageDraw, ImageFont
from transformers import CLIPTextModel, CLIPTokenizer

from my_model import unet_2d_condition
from utils import compute_loco_v2
# Suppress traceback frames for unhandled exceptions: with sys.tracebacklimit
# set to 0, only the exception type and message are printed (presumably to
# keep the demo's error output short).
sys.tracebacklimit = 0
class Blocks(gr.Blocks):
    """``gr.Blocks`` subclass that adds page metadata to the app config.

    The ``thumbnail``, ``url`` and ``creator`` keyword arguments are removed
    from ``kwargs`` before the parent constructor runs. They are stored in
    ``extra_configs`` and merged into the dict returned by
    ``get_config_file``.
    """

    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        # Take the custom metadata out first so gr.Blocks never receives it.
        thumbnail = kwargs.pop('thumbnail', '')
        url = kwargs.pop('url', 'https://gradio.app/')
        creator = kwargs.pop('creator', '@teamGradio')
        self.extra_configs = {
            'thumbnail': thumbnail,
            'url': url,
            'creator': creator,
        }
        super().__init__(theme, analytics_enabled, mode, title, css, **kwargs)
        # Installs a process-wide "ignore" filter for every warning category.
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        """Return the parent config with ``extra_configs`` written on top."""
        config = super().get_config_file()
        config.update(self.extra_configs)
        return config
def draw_box(boxes=None, texts=None, img=None):
    """Draw coloured, labelled bounding boxes on an image.

    Args:
        boxes: sequence of ``(x1, y1, x2, y2)`` pixel boxes; ``None`` means none.
        texts: label per box, matched by index. A box without a label gets an
            empty caption bar instead of raising IndexError.
        img: PIL image that is drawn on in place. When ``None``, a white
            512x512 RGB canvas is created.

    Returns:
        The annotated image, or ``None`` when there are no boxes and no image.

    Raises:
        OSError: if Pillow cannot locate ``DejaVuSansMono.ttf``.
    """
    # None defaults instead of shared mutable [] defaults.
    boxes = [] if boxes is None else boxes
    texts = [] if texts is None else texts
    if len(boxes) == 0 and img is None:
        return None
    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    painter = ImageDraw.Draw(img)
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
    for bid, box in enumerate(boxes):
        color = colors[bid % len(colors)]
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        painter.rectangle([x1, y1, x2, y2], outline=color, width=4)
        anno_text = texts[bid] if bid < len(texts) else ''
        # Filled caption bar along the bottom edge of the box. Its width
        # presumably approximates monospace glyphs as ~0.6 * font size wide.
        painter.rectangle(
            [x1, y2 - int(font.size * 1.2), x1 + int((len(anno_text) + 0.8) * font.size * 0.6), y2],
            outline=color, fill=color, width=4)
        painter.text([x1 + int(font.size * 0.2), y2 - int(font.size * 1.2)], anno_text, font=font,
                     fill=(255, 255, 255))
    return img
def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size,
              loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
    """Sample images with LoCo backward guidance plus classifier-free guidance.

    A 50-step DDIM sampler starts from seeded Gaussian latents.

    During each of the first ``max_index_step`` steps, the latents are updated
    by gradient descent on the LoCo attention loss (``compute_loco_v2``). This
    repeats until ``loss / loss_scale <= loss_threshold`` or ``max_iter``
    updates have been made. ``loss`` persists across steps, so once it falls
    to the threshold no later step is guided.

    Every step then takes a classifier-free-guided DDIM update. The final
    latents are decoded by the VAE.

    Args:
        device: device the models live on; inputs are moved there.
        unet: UNet whose call returns ``(noise_pred, attn_up, attn_mid, attn_down)``.
            ``noise_pred.sample`` holds the noise prediction.
        vae: autoencoder whose ``decode(latents).sample`` is the image batch.
        tokenizer, text_encoder: tokenizer / text encoder used for ``prompt``.
        prompt: text prompt.
        bboxes, object_positions: forwarded unchanged to ``compute_loco_v2``.
        batch_size: number of latents sampled. The embeddings are not repeated
            per sample, so values other than 1 presumably fail inside the
            UNet (TODO confirm).
        loss_scale: factor applied to the loss before differentiation. 0
            disables guidance (its gradient would be zero anyway).
        loss_threshold: guidance stops once ``loss / loss_scale`` reaches it.
        max_iter: maximum guidance updates per denoising step.
        max_index_step: number of leading denoising steps that may be guided.
        rand_seed: seed for the initial noise (reseeds torch's global generator).
        guidance_scale: classifier-free guidance weight.

    Returns:
        List of ``batch_size`` PIL images.
    """
    batch_size = int(batch_size)

    # The prompt embeddings are constant for the whole run. Compute them once,
    # without building an autograd graph through the text encoder.
    with torch.no_grad():
        uncond_input = tokenizer(
            [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
        input_ids = tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0].unsqueeze(0).to(device)
        cond_embeddings = text_encoder(input_ids)[0]
    # Unconditional first, prompt second: must match the chunk(2) order below.
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    generator = torch.manual_seed(rand_seed)  # Seed generator to create the initial latent noise
    latents = torch.randn(
        (batch_size, 4, 64, 64),
        generator=generator,
    ).to(device)

    noise_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                    clip_sample=False, set_alpha_to_one=False)
    noise_scheduler.set_timesteps(50)
    latents = latents * noise_scheduler.init_noise_sigma

    # Large sentinel so the guidance loop is entered on the first step.
    loss = torch.tensor(10000)
    for index, t in enumerate(noise_scheduler.timesteps):
        iteration = 0
        # With loss_scale == 0 the threshold test would divide by zero and the
        # guidance gradient would be zero anyway, so skip guidance entirely.
        while (loss_scale != 0
               and loss.item() / loss_scale > loss_threshold
               and iteration < max_iter
               and index < max_index_step):
            latents = latents.requires_grad_(True)
            latent_model_input = noise_scheduler.scale_model_input(latents, t)
            _, attn_map_integrated_up, attn_map_integrated_mid, _ = \
                unet(latent_model_input, t, encoder_hidden_states=cond_embeddings)
            # Backward guidance: step the latents down the LoCo-loss gradient.
            loss = compute_loco_v2(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
                                   object_positions=object_positions) * loss_scale
            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
            latents = latents - grad_cond
            iteration += 1
            torch.cuda.empty_cache()
        torch.cuda.empty_cache()

        with torch.no_grad():
            # Classifier-free guidance: one batch with both branches.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
            noise_pred, _, _, _ = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
            noise_pred = noise_pred.sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
            torch.cuda.empty_cache()

    # Decode image
    with torch.no_grad():
        latents = 1 / 0.18215 * latents  # undo the SD v1 VAE latent scaling factor
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] to [0, 1] and clip
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def get_concat(ims):
    """Tile images into a white grid: one column for a single image, else two.

    The canvas is sized from ``ims[0]``. Each image is pasted at its own
    width/height times its column/row index. An empty list raises IndexError.
    """
    n_col = 1 if len(ims) == 1 else 2
    n_row = math.ceil(len(ims) / 2)
    first = ims[0]
    canvas = Image.new('RGB', (first.width * n_col, first.height * n_row), color="white")
    for idx, tile in enumerate(ims):
        row_id, col_id = divmod(idx, n_col)
        canvas.paste(tile, (tile.width * col_id, tile.height * row_id))
    return canvas
def click_on_display(language_instruction, grounding_texts, sketch_pad,
                     loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
                     state):
    """Show ``./images/dog.png`` in the output slot, ignoring the generation inputs.

    ``state`` receives an empty ``'boxes'`` list if it has none.

    Returns:
        ``[image_update, state]``.
    """
    state.setdefault('boxes', [])
    sample_image = Image.open('./images/dog.png')
    return [gr.Image.update(value=sample_image, visible=True), state]
def _parse_grounding_texts(grounding_texts):
    """Split a ';'-separated grounding instruction into stripped, non-empty phrases."""
    return [phrase.strip() for phrase in grounding_texts.split(';') if phrase.strip()]


def _word_positions(prompt_words, phrase):
    """Return, for each word of ``phrase``, its first index in ``prompt_words`` plus one.

    The +1 presumably skips the tokenizer's start-of-text token, i.e. it
    assumes one token per word (TODO confirm for words the tokenizer splits).

    Raises:
        ValueError: if a word of ``phrase`` is not one of ``prompt_words``.
    """
    positions = []
    for word in phrase.split():
        if word not in prompt_words:
            raise ValueError(f"Grounding word '{word}' does not appear in the text prompt.")
        positions.append(prompt_words.index(word) + 1)
    return positions


def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
             loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
             state):
    """Gradio callback for the Generate button.

    Pairs each ';'-separated grounding phrase with the box drawn for it (in
    drawing order). It then locates the phrase words in the prompt and runs
    ``inference`` on CUDA when available, otherwise on the CPU.

    Args:
        unet, vae, tokenizer, text_encoder: models bound by the caller.
        language_instruction: text prompt.
        grounding_texts: ';'-separated grounding phrases; empty entries are ignored.
        sketch_pad: unused; boxes come from ``state['boxes']``.
        loss_threshold, guidance_scale, batch_size, rand_seed, max_step,
            loss_scale, max_iter: forwarded to ``inference`` (``max_step`` as
            its ``max_index_step``).
        state: session dict. ``state['boxes']`` holds pixel boxes drawn on the
            512x512 sketch pad and is created empty if missing.

    Returns:
        Four image updates (for ``batch_size <= 4``) followed by ``state``.

    Raises:
        ValueError: no grounding phrase was given, the box and phrase counts
            differ, or a grounding word is not a word of the prompt.
    """
    boxes = state.setdefault('boxes', [])
    phrases = _parse_grounding_texts(grounding_texts)
    if not phrases:
        raise ValueError("Please enter at least one grounding object in the grounding instruction.")
    if len(boxes) != len(phrases):
        raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(phrases)))

    # Pixel boxes on the 512x512 pad -> [0, 1] coordinates, one box list per phrase.
    bboxes = [[box] for box in (np.asarray(boxes) / 512).tolist()]
    # Whitespace-tolerant split; trailing periods are dropped so "ball." matches "ball".
    prompt_words = language_instruction.strip().strip('.').split()
    object_positions = [_word_positions(prompt_words, phrase) for phrase in phrases]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, bboxes,
                           object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step,
                           rand_seed, guidance_scale)
    # Four image slots: the generated images, one visible blank to even out an
    # odd batch (> 1), then hidden blanks.
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    updates = [gr.Image.update(value=x, visible=True) for x in gen_images] \
        + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
        + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
    return updates + [state]
def binarize(x):
    """Return a uint8 array that is 255 wherever ``x`` is non-zero and 0 elsewhere."""
    nonzero = x != 0
    as_bytes = nonzero.astype('uint8')
    return as_bytes * 255
def sized_center_crop(img, cropx, cropy):
    """Return the centred ``cropy``-row by ``cropx``-column window of ``img``.

    Uses basic slicing, so for numpy arrays the result is a view.
    """
    rows, cols = img.shape[:2]
    top = rows // 2 - cropy // 2
    left = cols // 2 - cropx // 2
    window = (slice(top, top + cropy), slice(left, left + cropx))
    return img[window]
def sized_center_fill(img, fill, cropx, cropy):
    """Set the centred ``cropy`` x ``cropx`` window of ``img`` to ``fill`` in place.

    Returns the same ``img`` object.
    """
    rows, cols = img.shape[:2]
    top = rows // 2 - cropy // 2
    left = cols // 2 - cropx // 2
    window = (slice(top, top + cropy), slice(left, left + cropx))
    img[window] = fill
    return img
def sized_center_mask(img, cropx, cropy):
    """Return a uint8 copy of ``img`` dimmed to 20% outside its centred window.

    The centred ``cropy`` x ``cropx`` window keeps the original values; the
    input array is not modified.
    """
    rows, cols = img.shape[:2]
    top = rows // 2 - cropy // 2
    left = cols // 2 - cropx // 2
    window = (slice(top, top + cropy), slice(left, left + cropx))
    kept = img[window].copy()
    dimmed = (img * 0.2).astype('uint8')
    dimmed[window] = kept
    return dimmed
def center_crop(img, HW=None, tgt_size=(512, 512)):
    """Crop a centred ``HW`` x ``HW`` square from ``img`` and resize it to ``tgt_size``.

    When ``HW`` is None, the largest centred square (side ``min(H, W)``) is
    used. Returns a numpy array.
    """
    side = min(img.shape[:2]) if HW is None else HW
    square = sized_center_crop(img, side, side)
    resized = Image.fromarray(square).resize(tgt_size)
    return np.array(resized)
def draw(input, grounding_texts, new_image_trigger, state):
    """Gradio callback: turn sketch-pad strokes into boxes and render a preview.

    Each call compares the binarised canvas with the last stored mask. If the
    newly drawn pixels span more than 5 px both horizontally and vertically,
    their bounding box is appended to ``state['boxes']`` and the mask to
    ``state['masks']``. An empty canvas resets ``state``.

    Args:
        input: sketch-pad value. Either a dict with a ``'mask'`` entry, or an
            array; ``None`` (cleared pad) is treated as an empty canvas.
        grounding_texts: ';'-separated phrases used as box labels. Missing
            labels become ``'Obj. N'``.
        new_image_trigger: returned unchanged.
        state: per-session dict holding ``'boxes'`` and ``'masks'``.

    Returns:
        ``[box_image, new_image_trigger, image_scale, state]``, where
        ``box_image`` is the boxes drawn on a blank canvas (``None`` if there
        are no boxes) and ``image_scale`` is always 1.0.
    """
    if isinstance(input, dict):
        mask = input['mask']
    elif input is None:
        # A cleared sketch pad behaves like an empty canvas.
        mask = np.zeros((512, 512), dtype='uint8')
    else:
        mask = input
    mask = np.asarray(mask)
    if mask.ndim == 3:
        # The pad is initialised to white (255), so invert channel 0 to make
        # drawn pixels non-zero.
        mask = 255 - mask[..., 0]

    image_scale = 1.0
    mask = binarize(mask)
    if mask.sum() == 0:
        # Empty canvas: forget all previously drawn boxes.
        state = {}
    # The preview is always drawn on a blank canvas created by draw_box; an
    # image array from the pad could not be drawn on directly.
    image = None

    state.setdefault('boxes', [])
    if not state.get('masks'):
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if mask.size > 1:
        # Newly drawn pixels only. Signed arithmetic avoids uint8 wrap-around
        # (erased pixels would otherwise become non-zero and spawn a box).
        diff_mask = np.clip(mask.astype(np.int32) - last_mask.astype(np.int32), 0, None)
    else:
        diff_mask = np.zeros([])

    if diff_mask.sum() > 0:
        xs = np.where(diff_mask.max(0) != 0)[0]
        ys = np.where(diff_mask.max(1) != 0)[0]
        x1, x2 = xs.min(), xs.max()
        y1, y2 = ys.min(), ys.max()
        # Ignore tiny strokes (clicks, specks).
        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    labels = [x.strip() for x in grounding_texts.split(';')]
    labels = [x for x in labels if len(x) > 0]
    if len(labels) < len(state['boxes']):
        labels += [f'Obj. {bid + 1}' for bid in range(len(labels), len(state['boxes']))]
    box_image = draw_box(state['boxes'], labels, image)
    return [box_image, new_image_trigger, image_scale, state]
def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
    """Reset the sketch pad, preview, image scale, generated images and state.

    ``sketch_pad_trigger`` is incremented unless ``task`` is
    ``'Grounded Inpainting'``. ``state`` and ``switch_task`` are accepted but
    unused.

    Returns:
        ``[None, trigger, None, 1.0, *image_updates, {}]`` with one cleared,
        visible image update per sample in ``batch_size``.
    """
    if task == 'Grounded Inpainting':
        next_trigger = sketch_pad_trigger
    else:
        next_trigger = sketch_pad_trigger + 1
    cleared_images = [gr.Image.update(value=None, visible=True) for _ in range(batch_size)]
    return [None, next_trigger, None, 1.0, *cleared_images, {}]
def main():
    """Load the Stable Diffusion v1.5 components and launch the LoCo Gradio demo.

    Builds the page, wires the module's callbacks to it, then enables the
    queue (one concurrent job) and launches the app. The page contains:
    prompt and grounding text boxes, a 512x512 sketch pad, the parsed-box
    preview, the generated image, advanced-option sliders and an example.
    """
    css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
height: var(--height) !important;
max-height: var(--height) !important;
min-height: var(--height) !important;
}
#paper-info a {
color:#008AD7;
text-decoration: none;
}
#paper-info a:hover {
cursor: pointer;
text-decoration: none;
}
.tooltip {
color: #555;
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 400px;
background-color: #555;
color: #fff;
text-align: center;
padding: 5px;
border-radius: 5px;
position: absolute;
z-index: 1; /* Set z-index to 1 */
left: 10px;
top: 100%;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
z-index: 9999; /* Set a high z-index value when hovering */
}
"""
    # Client-side JS: sizes the sketch pad from #image_scale and hides two toolbar buttons.
    rescale_js = """
function(x) {
const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
const image_width = root.querySelector('#img2img_image').clientWidth;
const target_height = parseInt(image_width * image_scale);
document.body.style.setProperty('--height', `${target_height}px`);
root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
return x;
}
"""
    sd_path = "runwayml/stable-diffusion-v1-5"
    # from_pretrained is called on the class, so no throwaway randomly
    # initialised UNet is built first. Assumes the project's
    # UNet2DConditionModel keeps diffusers' classmethod from_pretrained,
    # which reads the checkpoint's own unet config (confirm).
    unet = unet_2d_condition.UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
    tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet.to(device)
    text_encoder.to(device)
    vae.to(device)

    with Blocks(
        css=css,
        analytics_enabled=False,
        title="LoCo: Locally Constrained Training-free Layout-to-Image Generation",
    ) as demo:
        description = """<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px">LoCo: Locally Constrained Training-free Layout-to-Image Generation</span>
<br>
<span style="font-size: 18px" id="paper-info">
[<a href=" " target="_blank">Project Page</a>]
[<a href=" " target="_blank">Paper</a>]
[<a href=" " target="_blank">GitHub</a>]
</span>
</p>
"""
        gr.HTML(description)

        with gr.Column():
            language_instruction = gr.Textbox(
                label="Text Prompt",
            )
            grounding_instruction = gr.Textbox(
                label="Grounding instruction (Separated by semicolon)",
            )
            # Hidden numbers used as event triggers / values read by rescale_js.
            sketch_pad_trigger = gr.Number(value=0, visible=False)
            sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
            init_white_trigger = gr.Number(value=0, visible=False)
            image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
            new_image_trigger = gr.Number(value=0, visible=False)

            with gr.Row():
                sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image", source='canvas', shape=(512, 512))
                out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
                out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")

            with gr.Row():
                clear_btn = gr.Button(value='Clear')
                gen_btn = gr.Button(value='Generate')

            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    description = """<div class="tooltip">Loss Scale Factor &#9432
<span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the better control we get while it sometimes losses fidelity. </span>
</div>
<div class="tooltip">Guidance Scale &#9432
<span class="tooltiptext">The scale factor of classifier-free guidance. </span>
</div>
<div class="tooltip" >Max Iteration per Step &#9432
<span class="tooltiptext">The max iterations of backward guidance in each diffusion inference process.</span>
</div>
<div class="tooltip" >Loss Threshold &#9432
<span class="tooltiptext">The threshold of loss. If the loss computed by cross-attention map is smaller then the threshold, the backward guidance is stopped. </span>
</div>
<div class="tooltip" >Max Step of Backward Guidance &#9432
<span class="tooltiptext">The max steps of backward guidance in diffusion inference process.</span>
</div>
"""
                    gr.HTML(description)
                    loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30, label="Loss Scale Factor")
                    guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
                    batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
                    max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
                    max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
                    rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")

        state = gr.State({})

        class Controller:
            """Callbacks that (re)initialise the sketch pad, plus call counters."""

            def __init__(self):
                self.calls = 0
                self.tracks = 0
                self.resizes = 0
                self.scales = 0

            def init_white(self, init_white_trigger):
                """Return a white 512x512 canvas, an image scale of 1.0 and the bumped trigger."""
                self.calls += 1
                return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1

            def change_n_samples(self, n_samples):
                """Return visibility updates for four image slots (not wired to any event here)."""
                blank_samples = n_samples % 2 if n_samples > 1 else 0
                return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
                    + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]

        controller = Controller()

        def generate_first_image(*args):
            """Run generate() and keep only the first image update and the state.

            generate() returns four image updates followed by the state, while
            the Generate button is wired to exactly two outputs.
            """
            results = generate(unet, vae, tokenizer, text_encoder, *args)
            return [results[0], results[-1]]

        # Bump the trigger on page load; its change event paints the canvas white.
        demo.load(
            lambda x: x + 1,
            inputs=sketch_pad_trigger,
            outputs=sketch_pad_trigger,
            queue=False)
        # Every stroke and every grounding-text edit re-parses boxes and redraws the preview.
        sketch_pad.edit(
            draw,
            inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
            outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
            queue=False,
        )
        grounding_instruction.change(
            draw,
            inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
            outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
            queue=False,
        )
        # clear() takes (task, sketch_pad_trigger, batch_size, state). The trigger
        # also fills the `task` slot; being a number it never equals
        # 'Grounded Inpainting', so the trigger is always incremented, which
        # repaints the canvas white through the change event below.
        clear_btn.click(
            clear,
            inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
            outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
            queue=False)
        sketch_pad_trigger.change(
            controller.init_white,
            inputs=[init_white_trigger],
            outputs=[sketch_pad, image_scale, init_white_trigger],
            queue=False)
        gen_btn.click(
            fn=generate_first_image,
            inputs=[
                language_instruction, grounding_instruction, sketch_pad,
                loss_threshold, guidance_scale, batch_size, rand_seed,
                max_step,
                loss_scale, max_iter,
                state,
            ],
            outputs=[out_gen_1, state],
            queue=True
        )
        sketch_pad_resize_trigger.change(
            None,
            None,
            sketch_pad_resize_trigger,
            _js=rescale_js,
            queue=False)
        init_white_trigger.change(
            None,
            None,
            init_white_trigger,
            _js=rescale_js,
            queue=False)

        with gr.Column():
            gr.Examples(
                examples=[
                    [
                        # "images/input.png",
                        "A hello kitty toy is playing with a purple ball.",
                        "hello kitty;ball",
                        "images/hello_kitty_results.png"
                    ],
                ],
                inputs=[language_instruction, grounding_instruction, out_gen_1],
                outputs=None,
                fn=None,
                cache_examples=False,
            )
        description = """<p> The source codes of the demo are modified based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GlIGen</a>. Thanks! </p>"""
        gr.HTML(description)

    demo.queue(concurrency_count=1, api_open=False)
    demo.launch(share=False, show_api=False, show_error=True)
if __name__ == "__main__":
    # Script entry point: load the models, build the UI and serve it.
    main()