"""State and callbacks for a 3D scatter plot of word embeddings.

At import time this module generates one image per entry of ``examples``,
projects the example embeddings onto three axes and builds a Plotly 3D
scatter figure. The functions below mutate that module-level state
(``coords``, ``images``, ``examples``, ``axis``, ``axis_names``, ``fig``).

NOTE(review): ``pipe``, ``tokenizer``, ``examples``, ``axis_names``,
``axisMap``, ``num_inference_steps``, ``guidance_scale``, the word lists
(``young``, ``old``, ...) and the helpers ``get_axis_embeddings``,
``calculate_residual``, ``get_concat_embeddings`` and ``get_word_embeddings``
are not defined in this file; presumably they come from the ``src.util``
star imports - confirm.
"""
import base64
import random
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from diffusers import StableDiffusionPipeline

from src.util.base import *
from src.util.params import *
from src.util.clip_config import *


def _render_image_uri(prompt):
    """Generate an image for *prompt* with ``pipe`` and return a JPEG data URI."""
    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64, " + encoded_image


def _project_words(words):
    """Project the embeddings of *words* onto the current ``axis`` rows.

    Returns one row per word and one column per axis, with column 1 rescaled
    as ``5 * (1 - value)``.
    """
    projected = get_concat_embeddings(words) @ axis.T
    # NOTE(review): the rescale is always applied to column 1, which is the
    # axis initialised as the residual below; confirm this is still intended
    # once set_axis() has moved the residual to another axis.
    projected[:, 1] = 5 * (1.0 - projected[:, 1])
    return projected


age = get_axis_embeddings(young, old)
gender = get_axis_embeddings(masculine, feminine)
royalty = get_axis_embeddings(common, elite)

# One data-URI image per initial example, kept index-aligned with `examples`.
images = [_render_image_uri(example) for example in examples]

axis = np.vstack([gender, royalty, age])
axis[1] = calculate_residual(axis, axis_names)

coords = _project_words(examples)


def update_fig():
    """Copy the current ``coords`` and ``examples`` into the figure's first trace.

    Returns:
        The (whitespace-only) string the original implementation returned.
    """
    trace = fig.data[0]
    trace.x = coords[:, 0]
    trace.y = coords[:, 1]
    trace.z = coords[:, 2]
    trace.text = examples
    return f""" """


def add_word(new_example):
    """Append *new_example* to the plot state and refresh the figure.

    The coordinate row and the image are both computed before any module
    state is touched, so a failure in either step cannot leave ``coords``,
    ``images`` and ``examples`` with different lengths.

    Returns:
        The value returned by :func:`update_fig`.
    """
    global coords

    new_coord = _project_words([new_example])
    image_uri = _render_image_uri(new_example)

    coords = np.vstack([coords, new_coord])
    images.append(image_uri)
    examples.append(new_example)
    return update_fig()


def remove_word(new_example):
    """Remove *new_example* from ``coords``, ``images`` and ``examples``.

    If the word occurs more than once, the last occurrence is removed.

    Raises:
        KeyError: if *new_example* is not in ``examples``.

    Returns:
        The value returned by :func:`update_fig`.
    """
    global coords

    # Later duplicates overwrite earlier ones, so this maps to the last index.
    index_by_word = {word: position for position, word in enumerate(examples)}
    index = index_by_word[new_example]

    coords = np.delete(coords, index, 0)
    images.pop(index)
    examples.pop(index)
    return update_fig()


def add_rem_word(new_examples):
    """Toggle each word in *new_examples*: remove it if plotted, else add it.

    Args:
        new_examples: words separated by commas and/or whitespace.

    Returns:
        The value returned by :func:`update_fig`.
    """
    for word in new_examples.replace(",", " ").split():
        if word in examples:
            remove_word(word)
            gr.Info(f"Removed {word}")
        # Presumably 3 tokens == start marker + one vocabulary token + end
        # marker, i.e. the word is a single token - confirm with the tokenizer.
        elif len(tokenizer.encode(word)) != 3:
            gr.Warning(f"{word} not found in embeddings")
        else:
            add_word(word)
            gr.Info(f"Added {word}")
    return update_fig()


def set_axis(axis_name, which_axis, from_words, to_words):
    """Redefine one plot axis and re-project every example onto the new axes.

    Args:
        axis_name: new axis label; the literal ``"residual"`` requests a
            residual axis computed by ``calculate_residual``.
        which_axis: key into ``axisMap`` selecting the axis row to replace.
        from_words: comma/whitespace separated words for one end of the axis
            (unused when *axis_name* is ``"residual"``).
        to_words: comma/whitespace separated words for the other end
            (unused when *axis_name* is ``"residual"``).

    Returns:
        The value returned by :func:`update_fig`.
    """
    global coords

    target = axisMap[which_axis]

    if axis_name != "residual":
        from_words = from_words.replace(",", " ").split()
        to_words = to_words.replace(",", " ").split()

        axis[target] = get_axis_embeddings(from_words, to_words)
        axis_names[target] = axis_name

        # Any axis still labelled "residual" is recomputed after the change.
        for i, name in enumerate(axis_names):
            if name == "residual":
                axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
    else:
        # The residual is computed before the label is updated, as before.
        axis[target] = calculate_residual(axis, axis_names, residual_axis=target)
        axis_names[target] = axis_name

    coords = _project_words(examples)

    fig.update_layout(
        scene=dict(
            xaxis_title=axis_names[0],
            yaxis_title=axis_names[1],
            zaxis_title=axis_names[2],
        )
    )
    return update_fig()


def change_word(examples):
    """Regenerate the image for each listed word by removing and re-adding it.

    Words that are not currently plotted are reported with a warning and
    skipped instead of aborting the whole batch with ``KeyError``.

    Args:
        examples: words separated by commas and/or whitespace. (The parameter
            shadows the module-level ``examples`` list inside this function.)

    Returns:
        The value returned by :func:`update_fig`.
    """
    for word in examples.replace(",", " ").split():
        try:
            remove_word(word)
        except KeyError:
            gr.Warning(f"{word} is not in the plot")
            continue
        add_word(word)
        gr.Info(f"Changed image for {word}")
    return update_fig()


def clear_words():
    """Remove every plotted word, last one first, and refresh the figure.

    Returns:
        The value returned by :func:`update_fig`.
    """
    while examples:
        remove_word(examples[-1])
    return update_fig()


def generate_word_emb_vis(prompt):
    """Render one embedding row of *prompt* as a one-row heat-map image.

    The result of ``get_word_embeddings`` is reshaped to ``(77, 768)`` and
    row 1 is drawn with the ``inferno`` colormap.

    Returns:
        A PNG data URI. The format is passed explicitly because ``imsave``
        on a file-like object otherwise falls back to the configured default
        (PNG), which did not match the previous ``image/jpeg`` label.
    """
    buf = BytesIO()
    emb = get_word_embeddings(prompt).reshape(77, 768)[1]
    plt.imsave(buf, [emb], cmap="inferno", format="png")
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    return "data:image/png;base64, " + encoded


fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    labels={
        "x": axis_names[0],
        "y": axis_names[1],
        "z": axis_names[2],
    },
    text=examples,
    height=750,
)
# Drop all figure margins and set the initial 3D camera position.
fig.update_layout(
    margin={"l": 0, "r": 0, "b": 0, "t": 0},
    scene_camera={"eye": {"x": 2, "y": 2, "z": 0.1}},
)

# Disable Plotly's built-in hover box on every trace.
fig.update_traces(hoverinfo="none", hovertemplate=None)

# Names exported by `from <this module> import *`.
__all__ = [
    "fig",
    "update_fig",
    "coords",
    "images",
    "examples",
    "add_word",
    "remove_word",
    "add_rem_word",
    "change_word",
    "clear_words",
    "generate_word_emb_vis",
    "set_axis",
    "axis",
]