import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose


class Diffusion(nn.Module):
    """Gaussian diffusion wrapper around a noise-predicting backbone.

    Holds a cosine beta schedule plus the derived posterior coefficients, and
    samples a sequence of frames autoregressively. Each frame is denoised from
    pure noise while conditioned on ``x_cond``, a rolling stack of previously
    generated (motion) frames, and a sliding window of audio embeddings.

    Args:
        nn_backbone: Module called as
            ``nn_backbone(xt, timesteps, x_cond, motion_frames=..., audio_emb=...)``.
            When ``in_channels == out_channels`` its output is the noise
            prediction. Otherwise the output is split in two along dim 1 into a
            noise prediction and a variance-interpolation term.
        device: Device that per-timestep coefficients are moved to by ``expand``.
        n_timesteps: Number of diffusion steps of the training schedule.
        in_channels: Channels of one generated frame.
        image_size: Height and width of one generated frame.
        out_channels: Channels of the backbone output.
        motion_transforms: Callable applied to frames before they are stacked
            as motion conditioning; a falsy value means an empty ``Compose``.
    """

    def __init__(
            self, nn_backbone, device, n_timesteps=1000, in_channels=3, image_size=128,
            out_channels=6, motion_transforms=None):
        super().__init__()

        self.nn_backbone = nn_backbone
        self.n_timesteps = n_timesteps
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x_shape = (image_size, image_size)
        self.device = device
        self.motion_transforms = motion_transforms or Compose([])

        # Timesteps fed to the backbone, aligned index-for-index with the
        # schedule tables below. ``space`` replaces this with a shorter list
        # of the original training timesteps.
        self.timesteps = torch.arange(n_timesteps)
        self.beta = self.get_beta_schedule()
        self.set_params()

    def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2,
               motion_channels=3):
        """Autoregressively sample one frame per audio embedding.

        Args:
            x_cond: Conditioning frame. Its batch size (dim 0) and device define
                those of the output, and it seeds the motion-frame stack.
            audio_emb: Audio embeddings indexed along dim 1, one per output frame.
            n_audio_motion_embs: Audio embeddings taken on each side of the
                current frame. The backbone sees a window of
                ``2 * n_audio_motion_embs + 1`` ids, clamped to
                ``[0, n_frames - 1]``.
            n_motion_frames: Number of transformed frames kept in the motion stack.
            motion_channels: Number of leading dim-1 channels dropped from the
                motion stack each step, before the newest transformed frame is
                appended. Presumably this equals the channel count of one
                transformed frame; TODO confirm.

        Returns:
            The sampled frames stacked along dim 1, presumably shaped
            ``(B, n_frames, in_channels, H, W)``.
        """
        with torch.no_grad():
            n_frames = audio_emb.shape[1]
            xT = torch.randn(
                x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]
            ).to(x_cond.device)
            if n_frames == 0:
                # Nothing to sample; return the empty (B, 0, C, H, W) tensor
                # instead of crashing on torch.cat([]).
                return xT

            last_frame = n_frames - 1
            # Window centred on frame 0: ids -n..n, clamped to the valid range
            # on both sides (short clips would otherwise index past the end).
            audio_ids = [0] * n_audio_motion_embs + [
                min(i, last_frame) for i in range(n_audio_motion_embs + 1)
            ]

            motion_frames = torch.cat(
                [self.motion_transforms(x_cond) for _ in range(n_motion_frames)], dim=1
            )

            samples = []
            for i in trange(n_frames, desc='Sampling'):
                sample_frame = self.sample_loop(
                    xT[:, i], x_cond, motion_frames, audio_emb[:, audio_ids]
                )
                samples.append(sample_frame.unsqueeze(1))
                # Drop the oldest motion channels and append the new frame.
                motion_frames = torch.cat(
                    [motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)],
                    dim=1,
                )
                # Slide the audio window one frame forward.
                audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, last_frame)]
        return torch.cat(samples, dim=1)

    def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
        """Run the reverse diffusion chain for one frame, starting from ``xT``.

        The loop walks ``self.timesteps`` from last to first. At each step the
        backbone receives the original timestep values. The schedule
        coefficients are looked up by position, which differs from the value
        after ``space``.
        """
        batch_size = xT.shape[0]
        xt = xT
        for i, t in reversed(list(enumerate(self.timesteps))):
            t = int(t)
            timesteps = torch.full((batch_size,), t, dtype=torch.long, device=xT.device)
            timesteps_ids = torch.full((batch_size,), i, dtype=torch.long, device=xT.device)
            nn_out = self.nn_backbone(
                xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb
            )
            mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
            # No noise is added on the final (t == 0) step.
            noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
            xt = mean + noise * torch.exp(logvar / 2)

        return xt

    @staticmethod
    def _take(table, ids):
        """Index a 1-D schedule ``table`` with ``ids`` moved to the table's device."""
        return table[torch.as_tensor(ids, device=table.device)]

    def get_p_params(self, xt, timesteps, nn_out):
        """Return mean and log-variance of the model's reverse step p(x_{t-1} | x_t).

        ``timesteps`` are positions into the schedule tables. When the backbone
        outputs extra channels, the second half ``nu`` (mapped from [-1, 1] to
        [0, 1]) interpolates the log-variance between ``log(beta)`` and the
        clipped ``log(beta_tilde)``.
        """
        log_beta = torch.log(self._take(self.beta, timesteps))
        if self.in_channels == self.out_channels:
            eps_pred = nn_out
            p_logvar = self.expand(log_beta)
        else:
            eps_pred, nu = nn_out.chunk(2, 1)
            nu = (nu + 1) / 2
            log_beta_tilde = self._take(self.log_beta_tilde_clipped, timesteps)
            p_logvar = nu * self.expand(log_beta) + (1 - nu) * self.expand(log_beta_tilde)

        p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)

        return p_mean, p_logvar

    def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
        """Return mean and log-variance of the posterior q(x_{t-1} | x_t, x_0).

        If ``x0`` is not given, it is reconstructed from ``xt`` and
        ``eps_pred`` and clamped to [-1, 1].

        Raises:
            ValueError: if neither ``eps_pred`` nor ``x0`` is provided.
        """
        if x0 is None:
            if eps_pred is None:
                raise ValueError("either eps_pred or x0 must be provided")
            # predict x0 from xt and eps_pred
            coef1_x0 = self.expand(self._take(self.coef1_x0, timesteps))
            coef2_x0 = self.expand(self._take(self.coef2_x0, timesteps))
            x0 = coef1_x0 * xt - coef2_x0 * eps_pred
            x0 = x0.clamp(-1, 1)

        # q(x_{t-1} | x_t, x_0)
        coef1_q = self.expand(self._take(self.coef1_q, timesteps))
        coef2_q = self.expand(self._take(self.coef2_q, timesteps))
        q_mean = coef1_q * x0 + coef2_q * xt

        q_logvar = self.expand(self._take(self.log_beta_tilde_clipped, timesteps))

        return q_mean, q_logvar

    def get_beta_schedule(self, max_beta=0.999):
        """Return the cosine beta schedule (offset s = 0.008) as a float32 tensor.

        Each beta is ``1 - alpha_bar(t2) / alpha_bar(t1)``, capped at ``max_beta``.
        """
        def alpha_bar(t):
            return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2

        betas = []
        for i in range(self.n_timesteps):
            t1 = i / self.n_timesteps
            t2 = (i + 1) / self.n_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas).float()

    def set_params(self):
        """Derive all per-timestep coefficient tables from ``self.beta``."""
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])

        self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        # beta_tilde[0] is 0 (alpha_bar_prev[0] == 1), so its log would be -inf;
        # substitute beta_tilde[1] for the first entry.
        self.log_beta_tilde_clipped = torch.log(
            torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]])
        )

        # to calculate x0 from eps_pred
        self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
        self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)

        # for q(x_{t-1} | x_t, x_0)
        self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)

    def space(self, n_timesteps_new):
        """Re-space the schedule to ``n_timesteps_new`` evenly spread sampling steps.

        ``self.timesteps`` keeps the original training timesteps of the kept
        steps, so repeated calls compose correctly. ``self.beta`` is recomputed
        so that the cumulative alpha products match the kept steps.

        Raises:
            ValueError: if ``n_timesteps_new`` is not in ``[2, self.n_timesteps]``.
        """
        # Positions into the *current* schedule that survive the re-spacing.
        kept_ids = self.space_timesteps(self.n_timesteps, n_timesteps_new)
        # Map positions back to the timesteps the backbone was trained on.
        current = [int(t) for t in self.timesteps]
        self.timesteps = [current[j] for j in kept_ids]
        self.n_timesteps = n_timesteps_new
        self.beta = self.get_spaced_beta(kept_ids)
        self.set_params()

    def space_timesteps(self, n_timesteps, target_timesteps):
        """Return ``target_timesteps`` rounded, evenly spaced ids from 0 to ``n_timesteps - 1``.

        Raises:
            ValueError: if ``target_timesteps`` is not in ``[2, n_timesteps]``
                (fewer would divide by zero; more would repeat steps).
        """
        if not 2 <= target_timesteps <= n_timesteps:
            raise ValueError(
                f"target_timesteps must be in [2, {n_timesteps}], got {target_timesteps}"
            )
        frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(target_timesteps):
            taken_steps.append(round(cur_idx))
            cur_idx += frac_stride
        return taken_steps

    def get_spaced_beta(self, kept_ids=None):
        """Return betas of the sub-schedule that keeps positions ``kept_ids``.

        The cumulative product of ``1 - beta`` over the result reproduces
        ``self.alpha_bar`` at the kept positions. ``kept_ids`` defaults to
        ``self.timesteps``.
        """
        if kept_ids is None:
            kept_ids = self.timesteps
        kept = {int(j) for j in kept_ids}
        last_alpha_cumprod = 1.0
        new_beta = []
        for i, alpha_cumprod in enumerate(self.alpha_bar):
            if i in kept:
                new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        return torch.tensor(new_beta)

    def expand(self, arr, dim=4):
        """Append singleton dims to ``arr`` until it has ``dim`` dims; move to ``self.device``."""
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)