DiffusionDemo / src /pipelines /embeddings.py
akameswa's picture
Update src/pipelines/embeddings.py
0f8b569 verified
raw
history blame contribute delete
No virus
5.55 kB
import random
import spaces
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline
import base64
from io import BytesIO
import plotly.express as px
from src.util.base import *
from src.util.params import *
from src.util.clip_config import *
# Semantic direction vectors in CLIP embedding space. Each axis is built
# from two opposing word lists (young/old, masculine/feminine, common/elite)
# defined in src.util.clip_config — presumably the difference of their mean
# embeddings; confirm against get_axis_embeddings.
age = get_axis_embeddings(young, old)
gender = get_axis_embeddings(masculine, feminine)
royalty = get_axis_embeddings(common, elite)
@spaces.GPU()
def generate_examples(examples=examples, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale):
    """Render one diffusion image per example prompt.

    Args:
        examples: Iterable of prompt strings (defaults to the module list).
        num_inference_steps: Denoising steps forwarded to the pipeline.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        List of base64-encoded JPEG data-URI strings, one per prompt.
    """

    def _render(prompt):
        # Run the diffusion pipeline and keep the first generated image.
        result = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def _to_data_uri(img):
        # Serialize the PIL image to JPEG in memory and wrap it as a data URI.
        stream = BytesIO()
        img.save(stream, format="JPEG")
        payload = base64.b64encode(stream.getvalue()).decode("utf-8")
        return "data:image/jpeg;base64, " + payload

    return [_to_data_uri(_render(prompt)) for prompt in examples]
# Pre-render the example images once at import time.
images = generate_examples()

# Projection axes, one per row: row 0 = gender (x), row 1 = royalty (y),
# row 2 = age (z).
axis = np.vstack([gender, royalty, age])
# Replace the y-axis with its residual component (semantics defined by
# calculate_residual in src.util — presumably orthogonalization against
# the other axes; confirm there).
axis[1] = calculate_residual(axis, axis_names)
# Project every example embedding onto the three axes.
coords = get_concat_embeddings(examples) @ axis.T
# Rescale/flip the y coordinate; the same transform is reapplied wherever
# coords are recomputed in this module.
coords[:, 1] = 5 * (1.0 - coords[:, 1])
def update_fig():
    """Sync the 3-D scatter trace with the current global state and return
    the HTML snippet embedding the Dash view.

    Returns:
        An HTML string: a script that cache-busts the iframe URL with a
        random query parameter, plus the iframe pointing at ``dash_tunnel``.
    """
    global coords, examples, fig
    fig.data[0].x = coords[:, 0]
    fig.data[0].y = coords[:, 1]
    fig.data[0].z = coords[:, 2]
    fig.data[0].text = examples
    # Fix: quote the src attribute — an unquoted {dash_tunnel} produces
    # invalid HTML and breaks outright if the URL ever contains spaces or
    # other attribute-terminating characters.
    return f"""
        <script>
            document.getElementById("html").src += "?rand={random.random()}"
        </script>
        <iframe id="html" src="{dash_tunnel}" style="width:100%; height:725px;"></iframe>
    """
def add_word(new_example):
    """Add *new_example* to the plot: project it onto the current axes,
    generate its image, and refresh the figure.

    Args:
        new_example: A single word/prompt string.

    Returns:
        The refreshed iframe HTML from :func:`update_fig`.
    """
    global coords, images, examples

    # Project onto the axes and apply the same y rescaling used for the
    # initial coordinates.
    projected = get_concat_embeddings([new_example]) @ axis.T
    projected[:, 1] = 5 * (1.0 - projected[:, 1])
    coords = np.vstack([coords, projected])

    # Generate the image and store it as a base64 JPEG data URI.
    generated = pipe(
        prompt=new_example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    stream = BytesIO()
    generated.save(stream, format="JPEG")
    encoded = base64.b64encode(stream.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded)

    examples.append(new_example)
    return update_fig()
def remove_word(new_example):
    """Remove *new_example* and its image/coordinates from the plot state.

    Args:
        new_example: The word to remove; raises KeyError if absent.

    Returns:
        The refreshed iframe HTML from :func:`update_fig`.
    """
    global coords, images, examples
    # Dict lookup resolves duplicates to the last occurrence.
    position = {word: i for i, word in enumerate(examples)}[new_example]
    coords = np.delete(coords, position, 0)
    images.pop(position)
    examples.pop(position)
    return update_fig()
def add_rem_word(new_examples):
    """Toggle each word in a comma/whitespace-separated string: remove it
    if already plotted, otherwise add it (single-token words only).

    Args:
        new_examples: String of words separated by commas and/or spaces.

    Returns:
        The refreshed iframe HTML from :func:`update_fig`.
    """
    global examples
    for word in new_examples.replace(",", " ").split():
        if word in examples:
            remove_word(word)
            gr.Info("Removed {}".format(word))
            continue
        # A length-3 encoding presumably means start token + one word
        # token + end token, i.e. the word maps to a single embedding —
        # confirm against the tokenizer in use.
        if len(tokenizer.encode(word)) == 3:
            add_word(word)
            gr.Info("Added {}".format(word))
        else:
            gr.Warning(f"{word} not found in embeddings")
    return update_fig()
def set_axis(axis_name, which_axis, from_words, to_words):
    """Replace one of the three projection axes, then recompute all
    coordinates and refresh the figure.

    Args:
        axis_name: Label for the new axis, or the literal "residual".
        which_axis: Key into the module-level axisMap ("X"/"Y"/"Z"-style;
            exact keys defined in clip_config — confirm there).
        from_words: Comma/space-separated words for the axis' negative end.
        to_words: Comma/space-separated words for the axis' positive end.

    Returns:
        The refreshed iframe HTML from update_fig().
    """
    global coords, examples, fig, axis_names
    if axis_name != "residual":
        # Named axis: build it from the two word lists.
        from_words, to_words = (
            from_words.replace(",", " ").split(),
            to_words.replace(",", " ").split(),
        )
        axis_emb = get_axis_embeddings(from_words, to_words)
        axis[axisMap[which_axis]] = axis_emb
        axis_names[axisMap[which_axis]] = axis_name
        # Any axis currently marked "residual" depends on the others, so
        # it must be recomputed after this change.
        for i, name in enumerate(axis_names):
            if name == "residual":
                axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
                axis_names[i] = "residual"
    else:
        # The selected axis itself becomes the residual of the others.
        residual = calculate_residual(
            axis, axis_names, residual_axis=axisMap[which_axis]
        )
        axis[axisMap[which_axis]] = residual
        axis_names[axisMap[which_axis]] = axis_name
    # Re-project every example and reapply the standard y rescaling.
    coords = get_concat_embeddings(examples) @ axis.T
    coords[:, 1] = 5 * (1.0 - coords[:, 1])
    fig.update_layout(
        scene=dict(
            xaxis_title=axis_names[0],
            yaxis_title=axis_names[1],
            zaxis_title=axis_names[2],
        )
    )
    return update_fig()
def change_word(examples):
    """Regenerate the image for each listed word by removing and re-adding it.

    Args:
        examples: Comma/space-separated words. NOTE: this parameter shadows
            the module-level ``examples`` list; the helpers called below
            mutate the global via their own ``global`` declarations.

    Returns:
        The refreshed iframe HTML from :func:`update_fig`.
    """
    words = examples.replace(",", " ").split()
    for word in words:
        remove_word(word)
        add_word(word)
        gr.Info("Changed image for {}".format(word))
    return update_fig()
def clear_words():
    """Remove every plotted word (and its image/coordinates), then refresh.

    Returns:
        The refreshed iframe HTML from :func:`update_fig`.
    """
    # Iterate over a snapshot, tail-first: remove_word mutates the global
    # list while we loop.
    for word in list(reversed(examples)):
        remove_word(word)
    return update_fig()
def generate_word_emb_vis(prompt):
    """Visualize the embedding of *prompt*'s first token as a 1x768
    color strip.

    Args:
        prompt: The word to visualize.

    Returns:
        A base64 PNG data-URI string of the heat strip.
    """
    buf = BytesIO()
    # Index 1 of the 77-token sequence — presumably the first real token
    # after the start-of-text embedding; confirm against the tokenizer.
    emb = get_word_embeddings(prompt).reshape(77, 768)[1]
    # Fix: plt.imsave writes PNG by default when given a buffer, but the
    # URI previously claimed image/jpeg. Make the format explicit and
    # label the data URI to match.
    plt.imsave(buf, [emb], cmap="inferno", format="png")
    img = "data:image/png;base64, " + base64.b64encode(buf.getvalue()).decode("utf-8")
    return img
# Initial 3-D scatter of every example word, positioned by its projection
# onto the three semantic axes and labeled with the word itself.
fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    labels={
        "x": axis_names[0],
        "y": axis_names[1],
        "z": axis_names[2],
    },
    text=examples,
    height=750,
)
# Tight margins; camera pulled back along x/y for an oblique initial view.
fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0), scene_camera=dict(eye=dict(x=2, y=2, z=0.1))
)
# Disable the default hover tooltips (hover display is handled elsewhere).
fig.update_traces(hoverinfo="none", hovertemplate=None)

# Public API of this module.
__all__ = [
    "fig",
    "update_fig",
    "coords",
    "images",
    "examples",
    "add_word",
    "remove_word",
    "add_rem_word",
    "change_word",
    "clear_words",
    "generate_word_emb_vis",
    "set_axis",
    "axis",
]