# src/pipelines/interpolate.py
import torch
import spaces
import gradio as gr
# The star imports below supply the generation helpers used in this file
# (get_text_embeddings, generate_latents, generate_images, export_as_gif,
# export_as_zip) together with shared parameters.
from src.util.base import *
from src.util.params import *


def interpolate_prompts(promptA, promptB, num_interpolation_steps):
    # Encode both prompts, then linearly interpolate between the two text
    # embeddings so the first frame corresponds to promptA and the last to promptB.
    text_embeddingsA = get_text_embeddings(promptA)
    text_embeddingsB = get_text_embeddings(promptB)
    interpolated_embeddings = []
    for i in range(num_interpolation_steps):
        # Sweep alpha across the full [0, 1] range; dividing by (steps - 1)
        # makes the final embedding land exactly on promptB.
        alpha = i / max(num_interpolation_steps - 1, 1)
        interpolated_embedding = torch.lerp(text_embeddingsA, text_embeddingsB, alpha)
        interpolated_embeddings.append(interpolated_embedding)
    return interpolated_embeddings
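# For reference, a minimal sketch of what a text-embedding helper like
# get_text_embeddings might do with a diffusers Stable Diffusion pipeline.
# This is an illustrative assumption, not the actual implementation in
# src.util.base; the function name, the `pipe` argument, and the device
# handling are hypothetical.
def _sketch_get_text_embeddings(prompt, pipe):
    # Tokenize to the CLIP text encoder's fixed length and return per-token
    # embeddings, e.g. shape (1, 77, 768) for Stable Diffusion 1.x.
    tokens = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(pipe.text_encoder.device))[0]
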
@spaces.GPU(enable_queue=True)
def display_interpolate_images(
    seed, promptA, promptB, num_inference_steps, num_images, progress=gr.Progress()
):
    # Fix the initial latents from the seed so every frame differs only in its
    # text conditioning, then render one image per interpolated embedding.
    latents = generate_latents(seed)
    num_images = num_images + 2  # add 2 for the first (promptA) and last (promptB) images
    text_embeddings = interpolate_prompts(promptA, promptB, num_images)
    images = []
    progress(0)

    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(latents, text_embeddings[i], num_inference_steps)
        images.append((image, f"{i + 1}"))

    progress(1, desc="Exporting as gif")
    export_as_gif(images, filename="interpolate.gif", reverse=True)

    # Bundle the frames and the run settings into a downloadable zip archive.
    fname = "interpolate"
    tab_config = {
        "Tab": "Interpolate",
        "First Prompt": promptA,
        "Second Prompt": promptB,
        "Number of Interpolation Steps": num_images,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, "outputs/interpolate.gif", f"outputs/{fname}.zip"


__all__ = ["display_interpolate_images"]
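
# A minimal wiring sketch showing how the exported display_interpolate_images
# could be hooked up to a Gradio UI. The components, labels, and default
# values below are illustrative assumptions, not DiffusionDemo's actual
# interface definition; the __main__ guard keeps normal imports unaffected.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        promptA = gr.Textbox(label="First prompt")
        promptB = gr.Textbox(label="Second prompt")
        seed = gr.Number(label="Seed", value=42, precision=0)
        steps = gr.Slider(1, 50, value=20, step=1, label="Inference steps per image")
        frames = gr.Slider(1, 10, value=3, step=1, label="Interpolation steps")
        run = gr.Button("Interpolate")
        gallery = gr.Gallery(label="Frames")
        gif = gr.Image(label="GIF")
        archive = gr.File(label="Zip")
        run.click(
            display_interpolate_images,
            inputs=[seed, promptA, promptB, steps, frames],
            outputs=[gallery, gif, archive],
        )
    demo.launch()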