kyleleey
first commit
98a77e0
raw
history blame
No virus
12.5 kB
import os
os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
os.environ['HF_HOME'] = '/viscam/u/zzli'
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available
# Suppress partial model loading warning
logging.set_verbosity_error()
import gc
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import tinycudann as tcnn
from video3d.diffusion.sd import StableDiffusion
from torch.cuda.amp import custom_bwd, custom_fwd
def seed_everything(seed):
    """Seed every RNG this module draws from.

    Besides torch (CPU and CUDA), Python's ``random`` module is seeded too,
    because ``StableDiffusion_VSD.train_lora`` uses ``random.random()`` for
    condition dropout; without it, runs with the same seed diverge.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def cleanup():
    """Free cached memory.

    Runs, in order: Python's garbage collector, PyTorch's CUDA cache release,
    and tiny-cuda-nn's temporary-memory release.
    """
    for release in (gc.collect, torch.cuda.empty_cache, tcnn.free_temporary_memory):
        release()
class StableDiffusion_VSD(StableDiffusion):
    """Variational Score Distillation (VSD) guidance built on ``StableDiffusion``.

    Besides the frozen models set up by the parent class (``vae``,
    ``text_encoder``, ``unet``), this loads a second UNet copy, ``unet_lora``,
    whose attention processors are replaced by trainable LoRA processors.
    ``train_step`` returns the VSD loss for a rendered image together with the
    LoRA fine-tuning loss produced by ``train_lora``.
    """

    def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1):
        """Load the LoRA UNet, attach LoRA attention processors and its scheduler.

        Args:
            device: device the LoRA UNet and LoRA layers are moved to.
            sd_version: '2.1', '2.0' or '1.5'; ignored when ``hf_key`` is given.
            hf_key: optional Hugging Face model key overriding ``sd_version``.
            torch_dtype: dtype used to load the LoRA UNet (also passed to the parent).
            lora_n_timestamp_samples: timesteps sampled per latent in ``train_lora``.

        Raises:
            ValueError: if no ``hf_key`` is given and ``self.sd_version`` is unsupported.
        """
        super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype)
        model_key = self._vsd_model_key(hf_key)

        print(f'[INFO] loading stable diffusion VSD modules...')
        self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
        cleanup()

        # Freeze every pretrained weight; only the LoRA processors attached below
        # are created after this loop, so they keep their default requires_grad.
        for module in (self.vae, self.text_encoder, self.unet, self.unet_lora):
            for p in module.parameters():
                p.requires_grad_(False)

        # Attach a LoRA processor to every attention layer of unet_lora.
        lora_attn_procs = {}
        for name in self.unet_lora.attn_processors.keys():
            # "attn1" processors get no cross-attention dim; all others use the
            # UNet's configured cross_attention_dim.
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else self.unet_lora.config.cross_attention_dim
            )
            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=self._vsd_lora_hidden_size(name),
                cross_attention_dim=cross_attention_dim,
            )
        self.unet_lora.set_attn_processor(lora_attn_procs)
        self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
            self.device
        )
        # NOTE(review): presumably removes the state-dict key-remapping hooks that
        # AttnProcsLayers registers, so lora_layers' state_dict keeps its own keys;
        # confirm against the pinned diffusers version.
        self.lora_layers._load_state_dict_pre_hooks.clear()
        self.lora_layers._state_dict_hooks.clear()

        self.lora_n_timestamp_samples = lora_n_timestamp_samples
        self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        print(f'[INFO] loaded stable diffusion VSD modules!')

    def _vsd_model_key(self, hf_key):
        """Return the Hugging Face model key to load the LoRA UNet/scheduler from.

        Raises:
            ValueError: if ``hf_key`` is None and ``self.sd_version`` is unsupported.
        """
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            return hf_key
        if self.sd_version == '2.1':
            return "stabilityai/stable-diffusion-2-1-base"
        if self.sd_version == '2.0':
            return "stabilityai/stable-diffusion-2-base"
        if self.sd_version == '1.5':
            return "runwayml/stable-diffusion-v1-5"
        raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

    def _vsd_lora_hidden_size(self, name):
        """Return the hidden size for the unet_lora attention processor called ``name``.

        Raises:
            ValueError: for a processor name outside mid_block/up_blocks/down_blocks,
                instead of silently reusing a stale size.
        """
        block_out_channels = self.unet_lora.config.block_out_channels
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            # The single character after the prefix is the block index.
            block_id = int(name[len("up_blocks.")])
            return list(reversed(block_out_channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            return block_out_channels[block_id]
        raise ValueError(f'Unexpected attention processor name: {name}')

    def train_lora(
        self,
        latents,
        text_embeddings,
        camera_condition
    ):
        """Return the denoising loss used to fine-tune the LoRA layers of ``unet_lora``.

        Each latent is repeated ``self.lora_n_timestamp_samples`` times. Each copy
        gets a timestep drawn uniformly from [0, num_train_timesteps) and is noised
        with ``scheduler_lora``. ``unet_lora`` is then regressed onto that
        scheduler's target (the noise for "epsilon", the velocity for
        "v_prediction"). With probability 0.1 the condition is replaced by zeros,
        which trains the unconditional branch used for guidance in ``train_step``.

        Args:
            latents: latent batch; detached here, so no gradient reaches the renderer.
            text_embeddings: two stacked halves; only the second (conditional) half is used.
            camera_condition: per-sample condition, flattened to ``(B, -1)`` as class labels.

        Returns:
            Scalar tensor, ``0.5 * MSE`` between prediction and target (in float32).

        Raises:
            ValueError: if ``scheduler_lora`` has an unknown prediction type.
        """
        B = latents.shape[0]
        lora_n_timestamp_samples = self.lora_n_timestamp_samples
        latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1)

        t = torch.randint(
            0,
            self.num_train_timesteps,
            [B * lora_n_timestamp_samples],
            dtype=torch.long,
            device=self.device,
        )

        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
        if self.scheduler_lora.config.prediction_type == "epsilon":
            target = noise
        elif self.scheduler_lora.config.prediction_type == "v_prediction":
            target = self.scheduler_lora.get_velocity(latents, noise, t)
        else:
            raise ValueError(
                f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
            )

        # use view-independent text embeddings in LoRA
        _, text_embeddings_cond = text_embeddings.chunk(2)
        # Condition dropout so the LoRA UNet also learns the zero-condition branch.
        if random.random() < 0.1:
            camera_condition = torch.zeros_like(camera_condition)

        noise_pred = self.unet_lora(
            noisy_latents,
            t,
            encoder_hidden_states=text_embeddings_cond.repeat(
                lora_n_timestamp_samples, 1, 1
            ),
            class_labels=camera_condition.reshape(B, -1).repeat(
                lora_n_timestamp_samples, 1
            ),
            cross_attention_kwargs={"scale": 1.0}
        ).sample

        loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
        return loss_lora

    def train_step(
        self,
        text_embeddings,
        text_embeddings_vd,
        pred_rgb,
        camera_condition,
        im_features,
        guidance_scale=7.5,
        guidance_scale_lora=7.5,
        loss_weight=1.0,
        min_step_pct=0.02,
        max_step_pct=0.98,
        return_aux=False
    ):
        """Compute the VSD loss for ``pred_rgb`` plus the LoRA fine-tuning loss.

        The VSD gradient is ``loss_weight * (1 - alphas[t]) * (eps_pretrain - eps_lora)``.
        It is applied to the VAE latents of ``pred_rgb`` through a detached-target
        MSE, so only the rendering receives this gradient.

        Args:
            text_embeddings: two stacked halves; the second half is used as the
                view-independent conditional embedding for the LoRA UNet.
            text_embeddings_vd: embeddings fed to the pretrained UNet (presumably
                view-dependent, ordered [uncond, cond] to match the chunking below).
            pred_rgb: rendered images, resized to 512x512 before VAE encoding.
            camera_condition: cast to ``torch_dtype`` but not otherwise used here.
            im_features: per-sample features used as LoRA class labels.
            guidance_scale: CFG scale for the pretrained UNet.
            guidance_scale_lora: CFG scale (over the condition) for the LoRA UNet.
            loss_weight: multiplier on the VSD gradient.
            min_step_pct, max_step_pct: timestep range as fractions of
                ``num_train_timesteps``.
            return_aux: if True, also return a dict with ``grad``, ``t`` and ``w``.

        Returns:
            ``{'loss_vsd', 'loss_lora'}``, or ``(losses, aux)`` when ``return_aux``.
        """
        pred_rgb = pred_rgb.to(self.torch_dtype)
        text_embeddings = text_embeddings.to(self.torch_dtype)
        text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype)
        camera_condition = camera_condition.to(self.torch_dtype)
        im_features = im_features.to(self.torch_dtype)

        # The LoRA UNet is conditioned on image features, not on camera_condition.
        condition_label = im_features

        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)

        # timestep ~ U(min_step_pct, max_step_pct) to avoid very high/low noise level
        min_step = int(self.num_train_timesteps * min_step_pct)
        max_step = int(self.num_train_timesteps * max_step_pct)
        t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)

        # encode image into latents with vae, requires grad!
        latents = self.encode_imgs(pred_rgb_512)

        # predict the noise residual with both UNets, NO grad!
        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            latent_model_input = torch.cat([latents_noisy] * 2)

            # Run the pretrained UNet with its class embedding disabled; restore
            # the embedding even if the forward pass raises.
            cls_embedding = self.unet.class_embedding
            self.unet.class_embedding = None
            try:
                noise_pred_pretrain = self.unet(
                    latent_model_input,
                    torch.cat([t, t]),
                    encoder_hidden_states=text_embeddings_vd,
                    class_labels=None,
                    cross_attention_kwargs=None,
                ).sample
            finally:
                self.unet.class_embedding = cls_embedding

            # use view-independent text embeddings in LoRA
            _, text_embeddings_cond = text_embeddings.chunk(2)
            noise_pred_est = self.unet_lora(
                latent_model_input,
                torch.cat([t, t]),
                encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
                # First half: conditioned on condition_label; second half: zero condition.
                class_labels=torch.cat(
                    [
                        condition_label.reshape(b, -1),
                        torch.zeros_like(condition_label.reshape(b, -1)),
                    ],
                    dim=0,
                ),
                cross_attention_kwargs={"scale": 1.0},
            ).sample

        noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2)
        noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * (
            noise_pred_pretrain_text - noise_pred_pretrain_uncond
        )

        assert self.scheduler.config.prediction_type == "epsilon"
        if self.scheduler_lora.config.prediction_type == "v_prediction":
            # Convert the LoRA v-prediction to epsilon: eps = sigma_t * x_t + alpha_t * v.
            alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
                device=latents_noisy.device, dtype=latents_noisy.dtype
            )
            alpha_t = alphas_cumprod[t] ** 0.5
            sigma_t = (1 - alphas_cumprod[t]) ** 0.5
            noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape(
                -1, 1, 1, 1
            ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1)

        # The chunk order must match the class_labels order above:
        # [conditioned, zero-condition].
        noise_pred_est_camera, noise_pred_est_uncond = noise_pred_est.chunk(2)
        noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * (
            noise_pred_est_camera - noise_pred_est_uncond
        )

        # w(t) = 1 - alphas[t], i.e. sigma_t^2 if alphas holds alphas_cumprod.
        w = (1 - self.alphas[t])
        grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est)
        grad = torch.nan_to_num(grad)

        # Detached-target trick: d(loss_vsd)/d(latents) == grad / batch_size.
        # Both operands are float32 so the loss also works when torch_dtype is half.
        targets = (latents - grad).detach()
        loss_vsd = 0.5 * F.mse_loss(latents.float(), targets.float(), reduction='sum') / latents.shape[0]

        loss_lora = self.train_lora(latents, text_embeddings, condition_label)

        loss = {
            'loss_vsd': loss_vsd,
            'loss_lora': loss_lora
        }

        if return_aux:
            aux = {'grad': grad, 't': t, 'w': w}
            return loss, aux
        else:
            return loss
if __name__ == '__main__':
    # Command-line demo: sample an image for a prompt, display it and save it
    # as "<prompt>.png".
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image; save BEFORE show(), because once the window is closed the
    # figure is gone and a later savefig() writes an empty canvas.
    plt.imshow(imgs[0])
    plt.savefig(f'{opt.prompt}.png')
    plt.show()