"""Convolutional VAE and class-conditional VAE (CVAE) Lightning modules.

Both models are trained with a multi-resolution STFT reconstruction loss
(auraloss) plus the KL divergence between the approximate posterior and a
standard-normal prior.
"""
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torch.optim import Optimizer
import lightning as L
from auraloss.freq import MultiResolutionSTFTLoss

from .blocks import UpResConvBlock, DownResConvBlock


def _as_channel_list(channels: Optional[Sequence[int]]) -> list:
    """Return a fresh list copy of ``channels``; raise if it is missing or empty."""
    if channels is None or len(channels) == 0:
        raise ValueError("channels must be a non-empty sequence of channel counts")
    return list(channels)


def _channel_pairs(channels: list) -> list:
    """Pair each stage's width with the next entry; the last stage keeps its width."""
    return list(zip(channels, channels[1:] + [channels[-1]]))


def _reparameterize(mean: Tensor, logvar: Tensor) -> Tensor:
    """Sample ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mean


def _kl_divergence(mean: Tensor, logvar: Tensor) -> Tensor:
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch.

    Averaging over the batch keeps this term on the same per-example scale as
    the mean-reduced reconstruction loss, so the balance between the two terms
    does not change with batch size.
    """
    kld_per_example = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1)
    return kld_per_example.mean()


class Encoder(nn.Module):
    """Map a ``(batch, in_channels, in_features)`` signal to posterior ``(mean, logvar)``.

    Args:
        in_channels: Number of input channels.
        in_features: Length of the input along the last axis. Must be divisible
            by ``2 ** len(channels)``; the final Linear layer is sized assuming
            each ``DownResConvBlock`` halves the length.
        out_features: Latent dimensionality; ``mean`` and ``logvar`` each have
            shape ``(batch, out_features)``.
        channels: Channel width of each down-sampling stage (non-empty).

    Raises:
        ValueError: If ``channels`` is missing/empty or ``in_features`` is not a
            multiple of the downscale factor.
    """

    def __init__(self,
                 in_channels: int,
                 in_features: int,
                 out_features: int,
                 channels: Optional[Sequence[int]] = None,
                 ) -> None:
        super().__init__()
        channels = _as_channel_list(channels)
        downscale = 2 ** len(channels)
        if in_features % downscale != 0:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of downscale factor ({downscale})"
            )

        modules = [
            nn.Conv1d(in_channels, channels[0], 1),
            nn.GELU(),
        ]
        for in_channel, out_channel in _channel_pairs(channels):
            modules.append(DownResConvBlock(in_channel, out_channel, 1))

        n_features = in_features // downscale
        modules += [
            nn.Flatten(),
            # One Linear produces both mean and logvar; they are split in forward().
            nn.Linear(n_features * channels[-1], 2 * out_features),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``."""
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar


class Decoder(nn.Module):
    """Map a ``(batch, in_features)`` latent to a ``(batch, out_channels, out_features)`` signal.

    Output values are squashed into (-1, 1) by ``tanh``.

    Args:
        out_channels: Number of output channels.
        in_features: Dimensionality of the latent input.
        out_features: Length of the output along the last axis. Must be
            divisible by ``2 ** len(channels)``; the initial length is sized
            assuming each ``UpResConvBlock`` doubles it.
        channels: Channel width of each up-sampling stage (non-empty).

    Raises:
        ValueError: If ``channels`` is missing/empty or ``out_features`` is not
            a multiple of the upscale factor.
    """

    def __init__(self,
                 out_channels: int,
                 in_features: int,
                 out_features: int,
                 channels: Optional[Sequence[int]] = None,
                 ) -> None:
        super().__init__()
        channels = _as_channel_list(channels)
        upscale = 2 ** len(channels)
        if out_features % upscale != 0:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of upscale factor ({upscale})"
            )

        n_features = out_features // upscale
        modules = [
            nn.Linear(in_features, n_features * channels[0]),
            nn.Unflatten(-1, (channels[0], n_features)),
        ]
        for in_channel, out_channel in _channel_pairs(channels):
            modules.append(UpResConvBlock(in_channel, out_channel, 1))
        # No activation after the output projection: forward() applies tanh.
        # A GELU here (min ~ -0.17) would stop the output from reaching values
        # below ~ -0.17, clipping the negative half of the signal.
        modules.append(nn.Conv1d(channels[-1], out_channels, 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Decode latent ``x`` into a signal bounded to (-1, 1)."""
        return torch.tanh(self.net(x))


class VAE(L.LightningModule):
    """Unconditional VAE trained with a multi-resolution STFT loss plus KL divergence.

    Args:
        io_channels: Channels of the input/reconstructed signal.
        io_features: Length of the input/reconstructed signal.
        latent_features: Latent dimensionality.
        channels: Encoder stage widths; the decoder uses them in reverse order.
            The caller's sequence is not modified.
        learning_rate: AdamW learning rate.
    """

    def __init__(self,
                 io_channels: int,
                 io_features: int,
                 latent_features: int,
                 channels: Sequence[int],
                 learning_rate: float):
        super().__init__()
        self.encoder = Encoder(io_channels, io_features, latent_features, channels)
        # Reversed *copy*: reversing in place would mutate the caller's list.
        self.decoder = Decoder(io_channels, latent_features, io_features, list(reversed(channels)))
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, eps: Optional[Tensor] = None) -> Tensor:
        """Decode ``eps`` or, if omitted, one latent drawn from the N(0, I) prior."""
        if eps is None:
            # The prior matched by the KL term is standard normal, not uniform.
            eps = torch.randn((1, self.latent_features), device=self.device)
        return self.decoder(eps)

    def loss_function(self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor) -> Tensor:
        """Return STFT reconstruction loss of ``x_hat`` against ``x`` plus the KL term."""
        # auraloss takes (prediction, target); spectral convergence normalises by
        # the target's magnitude, so the argument order matters.
        audio_loss = self.audio_loss_func(x_hat, x)
        return audio_loss + _kl_divergence(mean, logvar)

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """Draw a differentiable sample from N(mean, exp(logvar))."""
        return _reparameterize(mean, logvar)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Encode, sample and decode ``x``; return ``(x_hat, mean, logvar)``."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        """Compute the training loss for ``batch`` and optionally log it as ``train_loss``."""
        x_hat, mean, logvar = self(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


class CVAE(L.LightningModule):
    """Class-conditional VAE.

    The class label is one-hot encoded. For the encoder it is projected to an
    extra input channel of length ``io_features``. For the decoder the one-hot
    vector is concatenated to the latent sample.

    Args:
        io_channels: Channels of the input/reconstructed signal.
        io_features: Length of the input/reconstructed signal.
        latent_features: Latent dimensionality.
        channels: Encoder stage widths; the decoder uses them in reverse order.
            The caller's sequence is not modified.
        num_classes: Number of conditioning classes.
        learning_rate: AdamW learning rate.
    """

    def __init__(self,
                 io_channels: int,
                 io_features: int,
                 latent_features: int,
                 channels: Sequence[int],
                 num_classes: int,
                 learning_rate: float):
        super().__init__()
        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
        # +1 input channel for the class embedding concatenated in forward().
        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
        # Reversed *copy*: reversing in place would mutate the caller's list.
        self.decoder = Decoder(io_channels, latent_features + num_classes, io_features,
                               list(reversed(channels)))
        self.num_classes = num_classes
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, c: Union[int, Tensor], eps: Optional[Tensor] = None) -> Tensor:
        """Decode samples conditioned on class label(s) ``c``.

        Args:
            c: A single class index (int or 0-d tensor) or a 1-D tensor of
                class indices, one per sample.
            eps: Optional latent of shape ``(len(c), latent_features)``; drawn
                from the N(0, I) prior when omitted.
        """
        labels = torch.as_tensor(c, dtype=torch.long, device=self.device)
        if labels.dim() == 0:
            labels = labels.unsqueeze(0)
        c_onehot = nn.functional.one_hot(labels, num_classes=self.num_classes).float()
        if eps is None:
            # The prior matched by the KL term is standard normal, not uniform.
            eps = torch.randn((c_onehot.shape[0], self.latent_features), device=self.device)
        z = torch.cat([eps, c_onehot], dim=1)
        return self.decoder(z)

    def loss_function(self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor) -> Tensor:
        """Return STFT reconstruction loss of ``x_hat`` against ``x`` plus the KL term."""
        # auraloss takes (prediction, target); spectral convergence normalises by
        # the target's magnitude, so the argument order matters.
        audio_loss = self.audio_loss_func(x_hat, x)
        return audio_loss + _kl_divergence(mean, logvar)

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """Draw a differentiable sample from N(mean, exp(logvar))."""
        return _reparameterize(mean, logvar)

    def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Encode ``x`` conditioned on integer labels ``c``; return ``(x_hat, mean, logvar)``."""
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        # (batch, 1, io_features): class embedding stacked as an extra channel.
        c_embedding = self.class_embedder(c).unsqueeze(1)
        x_embedding = self.data_embedder(x)
        x = torch.cat([x_embedding, c_embedding], dim=1)
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        z = torch.cat([z, c], dim=1)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int, log: bool = True) -> Tensor:
        """Compute the training loss for an ``(x, c)`` batch and optionally log it."""
        x, c = batch
        x_hat, mean, logvar = self(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)