import io
import os
import torch
import zipfile
import spaces
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm

# These star imports supply the module-level defaults used below: tokenizer,
# text_encoder, unet, vae, scheduler, torch_device, imageHeight, imageWidth,
# guidance_scale (params) and axis_combinations (clip_config).
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt


@spaces.GPU(enable_queue=True)
def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
    """Encode a prompt and its negative counterpart for classifier-free guidance."""
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    # Order matters: generate_images() later splits this tensor back into
    # (unconditional, conditional) halves with chunk(2).
    return torch.cat([uncond_embeddings, text_embeddings])


@spaces.GPU(enable_queue=True)
def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
    """Draw seeded Gaussian latents; latent space is 1/8 of the image resolution."""
    generator = torch.Generator().manual_seed(int(seed))
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    return latents


@spaces.GPU(enable_queue=True)
def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
    """Generate base latents and, when `poke` is set, a copy whose patch around
    (pokeX, pokeY) is replaced with freshly drawn noise.

    The poke coordinates and sizes are in latent-space units; the patch size is
    multiplied by 8 here because generate_latents divides by 8 internally.
    """
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if not poke:
        return original_latents, None

    np.random.seed(seed)
    poke_latents = generate_latents(
        np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
    )

    # Paste the noise patch, centred on (pokeX, pokeY).
    x_origin = pokeX - pokeWidth // 2
    y_origin = pokeY - pokeHeight // 2
    modified_latents = original_latents.clone()
    modified_latents[
        :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
    ] = poke_latents

    return original_latents, modified_latents


def convert_to_pil_image(image):
    """Map a decoded VAE tensor in [-1, 1] to a PIL image (first batch element)."""
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return Image.fromarray(images[0])


@spaces.GPU(enable_queue=True)
def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
    """Run the denoising loop.

    Returns a single PIL image, or, when `intermediate` is True, a list of
    (PIL image, step label) tuples with one entry per denoising step.
    """
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    images = []
    for i, t in enumerate(tqdm(scheduler.timesteps), start=1):
        # One UNet forward pass covers both halves of text_embeddings, so the
        # latents are duplicated to match.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        if intermediate:
            # Timesteps count down from the 1000 training steps.
            progress((1000 - t) / 1000)
            scaled_latents = 1 / 0.18215 * latents  # undo the SD VAE scaling factor
            with torch.no_grad():
                image = vae.decode(scaled_latents).sample
            images.append((convert_to_pil_image(image), str(i)))

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    if not intermediate:
        scaled_latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(scaled_latents).sample
        images = convert_to_pil_image(image)

    return images
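# A minimal sketch of how the functions above compose into a single
# text-to-image call. It is illustrative only: the prompt, seed, and step
# count are made up, and the models come from the src.util.params star import.
def _demo_text_to_image(prompt="a watercolor fox in a forest", seed=42, steps=25):
    embeddings = get_text_embeddings(prompt)
    latents = generate_latents(seed)
    return generate_images(latents, embeddings, num_inference_steps=steps)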
@spaces.GPU(enable_queue=True)
def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
    """Encode a prompt into a single flattened, L2-normalised embedding vector."""
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)
    text_embeddings = text_embeddings.cpu().numpy()
    return text_embeddings / np.linalg.norm(text_embeddings)


def get_concat_embeddings(names, merge=False):
    """Stack the embeddings of `names` row-wise; optionally average into one row."""
    embeddings = np.vstack([get_word_embeddings(name) for name in names])
    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)
    return embeddings


def get_axis_embeddings(A, B):
    """Build a semantic axis as the mean of the pairwise differences emb(a) - emb(b)."""
    emb = [get_word_embeddings(a) - get_word_embeddings(b) for a, b in zip(A, B)]
    return np.average(np.vstack(emb), axis=0).reshape(1, -1)


def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
    """Remove the components tied to the two non-residual axes from the merged
    word embeddings, caching each axis's word list in the module-level
    `axis_combinations` dict (from src.util.clip_config)."""
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T

    partial_residual = fembeddings - (fprojections.reshape(-1, 1) * fembeddings)
    residual = partial_residual - (sprojections.reshape(-1, 1) * sembeddings)

    return residual


def calculate_step_size(num_images, start_degree_circular, end_degree_circular):
    """Angular increment between consecutive frames of a circular sweep."""
    return (end_degree_circular - start_degree_circular) / num_images


def generate_seed_vis(seed):
    """Render a small colour strip that visually fingerprints a seed."""
    np.random.seed(seed)
    emb = np.random.rand(15)
    plt.close()
    plt.switch_backend("agg")
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt


def export_as_gif(images, filename, frames_per_second=2, reverse=False):
    """Save a list of (PIL image, label) tuples as outputs/<filename>."""
    imgs = [img[0] for img in images]
    if reverse:
        # Append the middle frames in reverse order for a back-and-forth loop.
        imgs += imgs[2:-1][::-1]
    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


def export_as_zip(images, fname, tab_config=None):
    """Write the images (plus an optional config.txt) into outputs/<fname>.zip,
    skipping the export when that archive already exists."""
    if os.path.exists(f"outputs/{fname}.zip"):
        return
    os.makedirs("outputs", exist_ok=True)
    with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
        if tab_config:
            with open("outputs/config.txt", "w") as f:
                for key, value in tab_config.items():
                    f.write(f"{key}: {value}\n")
            img_zip.write("outputs/config.txt", "config.txt")

        # Zero-pad the frame numbers so the archive entries sort in frame order.
        pad_width = len(str(len(images)))
        for idx, img in enumerate(images):
            buff = io.BytesIO()
            img[0].save(buff, format="PNG")
            img_zip.writestr(f"{idx + 1:0{pad_width}}.png", buff.getvalue())


def read_html(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
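# A minimal sketch of how the CLIP-embedding helpers above fit together. The
# word lists are illustrative and not taken from clip_config; the sign of each
# projection indicates which end of the axis a word leans toward.
def _demo_semantic_axis():
    masc = ["king", "man", "father"]
    fem = ["queen", "woman", "mother"]
    gender_axis = get_axis_embeddings(masc, fem)  # mean of pairwise differences
    words = get_concat_embeddings(["doctor", "nurse"])  # unit-normalised rows
    return words @ gender_axis.T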
"export_as_zip", "read_html", ]