import io
import os
import torch
import zipfile
import spaces
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt


@spaces.GPU(enable_queue=True)
def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
    # Encode the prompt with the CLIP text encoder.
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    # Encode the (possibly empty) negative prompt for classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    # Stack [unconditional, conditional] so one UNet pass covers both branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    return text_embeddings
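
# Example (illustrative; `tokenizer` and `text_encoder` come from the wildcard
# import of src.util.params and are assumed to be a Stable Diffusion CLIP
# tokenizer/encoder pair):
#   emb = get_text_embeddings("a photo of an astronaut riding a horse")
#   # emb.shape -> (2, 77, hidden_dim): row 0 unconditional, row 1 conditional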


@spaces.GPU(enable_queue=True)
def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
    # Seed a CPU generator so the same seed reproduces the same latents
    # regardless of the target device, then move the noise to that device.
    # The VAE downsamples by 8, hence the // 8 on the spatial dimensions.
    generator = torch.Generator().manual_seed(int(seed))
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    return latents
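
# Example (seed value is illustrative):
#   latents = generate_latents(42)
#   # latents.shape -> (1, unet.config.in_channels,
#   #                   imageHeight // 8, imageWidth // 8)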


@spaces.GPU(enable_queue=True)
def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if poke:
        # Draw a derived seed so the poked patch differs from the base latents
        # but is still reproducible from the user-supplied seed.
        np.random.seed(int(seed))
        poke_latents = generate_latents(
            np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
        )
        # pokeX/pokeY and pokeWidth/pokeHeight are latent-space coordinates;
        # overwrite the rectangle centered on (pokeX, pokeY) with fresh noise.
        x_origin = pokeX - pokeWidth // 2
        y_origin = pokeY - pokeHeight // 2
        modified_latents = original_latents.clone()
        modified_latents[
            :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
        ] = poke_latents
    else:
        modified_latents = None
    return original_latents, modified_latents
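
# Example ("poking" a 16x16 latent patch centered at (32, 32), i.e. a
# 128x128-pixel region of a 512x512 image; all values are illustrative):
#   base, poked = generate_modified_latents(
#       True, 42, pokeX=32, pokeY=32, pokeHeight=16, pokeWidth=16
#   )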


def convert_to_pil_image(image):
    # Map VAE output from [-1, 1] to [0, 1], reorder NCHW -> NHWC,
    # and quantize to 8-bit RGB.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]


@spaces.GPU(enable_queue=True)
def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    images = []
    for i, t in enumerate(tqdm(scheduler.timesteps), start=1):
        # Duplicate the latents so the unconditional and conditional
        # predictions run in a single batched UNet forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample
        # Classifier-free guidance: push the prediction away from the
        # unconditional branch and toward the text-conditioned one.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )
        if intermediate:
            # Timesteps run from ~1000 down to 0, so this maps to [0, 1].
            progress((1000 - t) / 1000)
            # 0.18215 is the Stable Diffusion VAE latent scaling factor.
            scaled_latents = 1 / 0.18215 * latents
            with torch.no_grad():
                image = vae.decode(scaled_latents).sample
            images.append((convert_to_pil_image(image), "{}".format(i)))
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    if not intermediate:
        scaled_latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(scaled_latents).sample
        images = convert_to_pil_image(image)
    return images
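
# Example end-to-end call (illustrative; assumes the globals loaded by
# src.util.params):
#   emb = get_text_embeddings("a watercolor fox")
#   lat = generate_latents(7)
#   img = generate_images(lat, emb, num_inference_steps)  # a single PIL.Image
#   # With intermediate=True the return value is instead a list of
#   # (PIL.Image, step_label) tuples, one per denoising step.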


@spaces.GPU(enable_queue=True)
def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)
    text_embeddings = text_embeddings.cpu().numpy()
    # L2-normalize so downstream dot products behave like cosine similarity.
    return text_embeddings / np.linalg.norm(text_embeddings)
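
# Example: since the returned vectors are unit-normalized, a dot product is
# a cosine similarity (words are illustrative):
#   cat, dog = get_word_embeddings("cat"), get_word_embeddings("dog")
#   similarity = float(cat @ dog.T)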


def get_concat_embeddings(names, merge=False):
    # Stack one embedding per name; with merge=True, average them into a
    # single row vector.
    embeddings = []
    for name in names:
        embedding = get_word_embeddings(name)
        embeddings.append(embedding)
    embeddings = np.vstack(embeddings)
    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)
    return embeddings


def get_axis_embeddings(A, B):
    # Build a semantic axis as the mean difference between paired word
    # embeddings from the two lists.
    emb = []
    for a, b in zip(A, B):
        e = get_word_embeddings(a) - get_word_embeddings(b)
        emb.append(e)
    emb = np.vstack(emb)
    ax = np.average(emb, axis=0).reshape(1, -1)
    return ax
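
# Example axis construction (word pairs are illustrative):
#   gender_axis = get_axis_embeddings(["king", "man"], ["queen", "woman"])
#   # Projecting a word embedding onto the axis scores it along that axis:
#   score = float(get_word_embeddings("prince") @ gender_axis.T)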


def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
    # Identify the two non-residual axes.
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    # Look up (or cache) the word lists that define each non-residual axis,
    # then merge their embeddings into one vector per axis.
    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    # Subtract each embedding's projection-weighted component along the two
    # fixed axes; what remains is used as the residual (third) axis.
    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T
    partial_residual = fembeddings - (fprojections.reshape(-1, 1) * fembeddings)
    residual = partial_residual - (sprojections.reshape(-1, 1) * sembeddings)
    return residual


def calculate_step_size(num_images, start_degree_circular, end_degree_circular):
    # Divide the sweep by num_images (not num_images - 1) so that, for a full
    # circle, the last frame does not duplicate the first when the GIF loops.
    return (end_degree_circular - start_degree_circular) / num_images
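
# Worked example: 5 images swept over 0..360 degrees gives a step of
# 360 / 5 = 72, so frames land at 0, 72, 144, 216, 288 -- the 360-degree
# frame is skipped because it would repeat the 0-degree frame in a loop.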


def generate_seed_vis(seed):
    # Render a small deterministic color strip so users can visually
    # distinguish one seed from another.
    np.random.seed(seed)
    emb = np.random.rand(15)
    plt.close()
    plt.switch_backend("agg")  # headless backend for server-side rendering
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt


def export_as_gif(images, filename, frames_per_second=2, reverse=False):
    imgs = [img[0] for img in images]
    if reverse:
        # Append interior frames in reverse order for a ping-pong loop.
        imgs += imgs[2:-1][::-1]
    os.makedirs("outputs", exist_ok=True)
    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )
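
# Example (filename is illustrative; `images` is the list of
# (PIL.Image, label) tuples produced by generate_images(intermediate=True)):
#   export_as_gif(images, "denoising.gif", frames_per_second=4, reverse=True)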


def export_as_zip(images, fname, tab_config=None):
    if not os.path.exists(f"outputs/{fname}.zip"):
        os.makedirs("outputs", exist_ok=True)
        with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
            if tab_config:
                # Write the run configuration alongside the images.
                with open("outputs/config.txt", "w") as f:
                    for key, value in tab_config.items():
                        f.write(f"{key}: {value}\n")
                img_zip.write("outputs/config.txt", "config.txt")
            # Zero-pad filenames so the images sort in generation order.
            num_digits = len(str(len(images)))
            img_name = f"{{:0{num_digits}}}.png"
            for idx, img in enumerate(images):
                buff = io.BytesIO()
                img[0].save(buff, format="PNG")
                img_zip.writestr(img_name.format(idx + 1), buff.getvalue())
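
# Example (fname is illustrative; the ".zip" suffix is appended internally):
#   export_as_zip(images, "run_42", tab_config={"seed": 42, "steps": 50})
#   # -> outputs/run_42.zip with config.txt plus one numbered PNG per frame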


def read_html(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return content


__all__ = [
    "get_text_embeddings",
    "generate_latents",
    "generate_modified_latents",
    "generate_images",
    "get_word_embeddings",
    "get_concat_embeddings",
    "get_axis_embeddings",
    "calculate_residual",
    "calculate_step_size",
    "generate_seed_vis",
    "export_as_gif",
    "export_as_zip",
    "read_html",
]