import torch
from src.utils.iimage import IImage
from pytorch_lightning import seed_everything
from tqdm import tqdm

from src.smplfusion import share, router, attentionpatch, transformerpatch
from src.smplfusion.patches.attentionpatch import painta
from src.utils import tokenize, scores

# When truthy, run() wraps its timestep range in a tqdm progress bar.
verbose = False


def init_painta(token_idx):
    """Switch the router to the painta forward functions and configure them.

    Installs ``attentionpatch.painta.forward`` and
    ``transformerpatch.painta.forward`` on ``router``, turns painta on,
    sets its resolution list to ``[16, 32]`` and stores ``token_idx`` on
    the ``painta`` module.

    Args:
        token_idx: Token positions handed to the painta patch unchanged.
    """
    router.attention_forward = attentionpatch.painta.forward
    router.basic_transformer_forward = transformerpatch.painta.forward

    painta.painta_on = True
    painta.painta_res = [16, 32]
    painta.token_idx = token_idx


def init_guidance():
    """Switch the router to the default forward functions used for guidance.

    The attention forward is the ``forward_and_save`` variant; run() later
    reads ``share._crossattn_similarity_res16`` to compute its score.
    """
    router.attention_forward = attentionpatch.default.forward_and_save
    router.basic_transformer_forward = transformerpatch.default.forward


def run(
    ddim,
    method,
    prompt,
    image,
    mask,
    seed=0,
    eta=0.1,
    negative_prompt='',
    positive_prompt='',
    num_steps=50,
    guidance_scale=7.5
):
    """Run the guided sampling loop and return the decoded image.

    Each step predicts noise with ``ddim.unet``, back-propagates a BCE score
    of the saved cross-attention similarities against ``share.mask16`` to
    the current latent, and uses the standardised gradient as the stochastic
    term of the DDIM update.

    Args:
        ddim: Model wrapper; must expose ``encoder``, ``unet``, ``vae``,
            ``config.scale_factor`` and ``get_inpainting_condition``.
        method: Value tested with ``'painta' in method`` to decide whether
            the painta patches are active (until ``share.timestep <= 500``).
        prompt: Text prompt; its tokens determine ``token_idx``.
        image: Input image object; padded with ``padx(64)``.
        mask: Mask object; processed with ``dilate(1).alpha().padx(64)``.
        seed: Seed passed to ``seed_everything`` before sampling the latent.
        eta: Multiplier applied to the schedule's sigma in the update.
        negative_prompt: Prompt encoded as the first (unconditional) context.
        positive_prompt: Optional suffix appended to ``prompt`` after ", ".
        num_steps: Number of sampling steps; the stride is ``1000 // num_steps``.
        guidance_scale: Weight of the conditional/unconditional difference.

    Returns:
        ``IImage`` built from the VAE decoding of the last predicted ``z0``.

    Note:
        Mutates global state on ``router``, ``painta`` and ``share``, and
        leaves ``ddim.unet`` with ``requires_grad_(True)``. Tensors are moved
        with ``.cuda()``, so a CUDA device is required.
    """
    image = image.padx(64)
    mask = mask.dilate(1).alpha().padx(64)

    full_prompt = f'{prompt}, {positive_prompt}' if positive_prompt != '' else prompt
    stride = 1000 // num_steps

    # Text condition: negative prompt first, full prompt second, matching the
    # (eps_uncond, eps) order the UNet output is split into below.
    context = ddim.encoder.encode([negative_prompt, full_prompt])

    # NOTE(review): the sentinel looked up here is the empty string exactly as
    # written in the source; if tokenize() never yields an empty token this
    # raises ValueError. Confirm the intended end-of-prompt marker.
    prompt_tokens = tokenize(prompt)
    token_idx = list(range(1, prompt_tokens.index('')))
    token_idx.append(tokenize(full_prompt).index(''))

    # Choose the initial router configuration.
    uses_painta = 'painta' in method
    if uses_painta:
        init_painta(token_idx)
    else:
        init_guidance()

    # Image condition
    unet_condition = ddim.get_inpainting_condition(image, mask)
    share.set_mask(mask)
    dtype = unet_condition.dtype

    # Starting latent: seeded noise with 4 channels and the condition's
    # trailing dimensions.
    seed_everything(seed)
    zt = torch.randn((1, 4) + unet_condition.shape[2:]).cuda().to(dtype)

    # Setup unet for guidance
    ddim.unet.requires_grad_(True)

    steps = range(999, 0, -stride)
    pbar = tqdm(steps) if verbose else steps

    for timestep in share.DDIMIterator(pbar):
        # With painta active, fall back to the guidance configuration once
        # the shared timestep is at or below 500.
        if uses_painta and share.timestep <= 500:
            init_guidance()

        # Fresh leaf tensor so that zt.grad is populated by backward().
        zt = zt.detach()
        zt.requires_grad = True

        # Reset storage
        share._crossattn_similarity_res16 = []

        # Run the model on a doubled batch (one copy per context entry).
        if unet_condition is None:
            unet_input = zt
        else:
            unet_input = torch.cat([zt, unet_condition], 1)

        with torch.autocast('cuda'):
            doubled_input = torch.cat([unet_input, unet_input]).to(dtype)
            doubled_timestep = torch.tensor([timestep, timestep]).cuda()
            unet_output = ddim.unet(
                doubled_input,
                timesteps=doubled_timestep,
                context=context,
            )
            eps_uncond, eps = unet_output.detach().chunk(2)

        # Unconditional guidance
        eps = eps_uncond + guidance_scale * (eps - eps_uncond)
        noise_coeff = share.schedule.sqrt_one_minus_alphas[timestep]
        signal_coeff = share.schedule.sqrt_alphas[timestep]
        z0 = (zt - noise_coeff * eps) / signal_coeff

        # Gradient computation: eps is detached, so the score's gradient
        # reaches zt through the saved attention similarities.
        score = scores.bce(
            share._crossattn_similarity_res16,
            share.mask16,
            token_idx=token_idx,
        )
        score.backward()
        grad = zt.grad.detach()
        ddim.unet.zero_grad()

        # DDIM step, with the unit-std gradient as the stochastic term.
        with torch.no_grad():
            sigma = share.schedule.sigma(share.timestep, stride)
            grad /= grad.std()

            prev_timestep = share.timestep - stride
            noise_scale = eta * sigma
            zt = (
                share.schedule.sqrt_alphas[prev_timestep] * z0
                + torch.sqrt(
                    1 - share.schedule.alphas[prev_timestep] - noise_scale ** 2
                ) * eps
                + noise_scale * grad
            )

    # Decode the last clean-latent estimate (z0, not zt).
    with torch.no_grad():
        return IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))