import torch
from torch import nn, Tensor
from torch.optim import Optimizer
from typing import Sequence

import lightning as L
from auraloss.freq import MultiResolutionSTFTLoss

from .blocks import UpResConvBlock, DownResConvBlock


class Encoder(nn.Module):
    """Convolutional VAE encoder: maps (batch, in_channels, in_features)
    waveforms to the mean and log-variance of the latent distribution."""

    def __init__(
        self,
        in_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        super().__init__()
        assert in_features % 2 ** len(channels) == 0, (
            f"in_features ({in_features}) must be a multiple of the "
            f"downscale factor ({2 ** len(channels)})"
        )
        channels = list(channels)  # allow any Sequence, e.g. a tuple
        modules: list[nn.Module] = [nn.Conv1d(in_channels, channels[0], 1), nn.GELU()]
        # One down block per entry in `channels`; the final block keeps the
        # channel count (channels[-1] -> channels[-1]). The feature arithmetic
        # below assumes each DownResConvBlock halves the time axis.
        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules.append(DownResConvBlock(in_channel, out_channel, 1))
        n_features = in_features // 2 ** len(channels)
        modules += [
            nn.Flatten(),
            # 2 * out_features: one half for the mean, one for the log-variance.
            nn.Linear(n_features * channels[-1], 2 * out_features),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar


class Decoder(nn.Module):
    """Convolutional VAE decoder: maps latent vectors of size in_features back
    to (batch, out_channels, out_features) waveforms in [-1, 1]."""

    def __init__(
        self,
        out_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        super().__init__()
        # Mirror the encoder's divisibility guard.
        assert out_features % 2 ** len(channels) == 0, (
            f"out_features ({out_features}) must be a multiple of the "
            f"upscale factor ({2 ** len(channels)})"
        )
        channels = list(channels)
        n_features = out_features // 2 ** len(channels)
        modules: list[nn.Module] = [
            nn.Linear(in_features, n_features * channels[0]),
            nn.Unflatten(-1, (channels[0], n_features)),
        ]
        # Each UpResConvBlock doubles the time axis, mirroring the encoder.
        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules.append(UpResConvBlock(in_channel, out_channel, 1))
        modules += [nn.Conv1d(channels[-1], out_channels, 1), nn.GELU()]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        # tanh keeps the reconstructed audio in [-1, 1].
        return torch.tanh(self.net(x))


class VAE(L.LightningModule):
    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        learning_rate: float,
    ) -> None:
        super().__init__()
        channels = list(channels)
        self.encoder = Encoder(io_channels, io_features, latent_features, channels)
        # The decoder mirrors the encoder's channel progression; reversed()
        # avoids the side effect of mutating the caller's list, which the
        # previous in-place channels.reverse() had.
        self.decoder = Decoder(
            io_channels, latent_features, io_features, list(reversed(channels))
        )
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, eps: Tensor | None = None) -> Tensor:
        if eps is None:
            # Draw from the standard-normal prior: torch.randn, not torch.rand
            # (uniform), matches the N(0, I) prior assumed by the KLD term.
            eps = torch.randn((1, self.latent_features), device=self.device)
        return self.decoder(eps)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        # auraloss expects (input, target), so the reconstruction goes first.
        audio_loss = self.audio_loss_func(x_hat, x)
        # KL divergence between N(mean, exp(logvar)) and the N(0, I) prior,
        # summed over latent dimensions and the batch.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        # Reparameterization trick: z = mean + std * eps keeps the sample
        # differentiable with respect to mean and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        x_hat, mean, logvar = self(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


class CVAE(L.LightningModule):
    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        num_classes: int,
        learning_rate: float,
    ) -> None:
        super().__init__()
        channels = list(channels)
        # The class label is embedded to a length-io_features vector and
        # concatenated to the data as one extra channel, hence io_channels + 1.
        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
        # The decoder is conditioned by appending the one-hot class to the latent.
        self.decoder = Decoder(
            io_channels,
            latent_features + num_classes,
            io_features,
            list(reversed(channels)),
        )
        self.num_classes = num_classes
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, c: Tensor, eps: Tensor | None = None) -> Tensor:
        c = nn.functional.one_hot(c.to(self.device), num_classes=self.num_classes)
        c = c.float().unsqueeze(0)
        if eps is None:
            # Standard-normal prior, as in VAE.sample.
            eps = torch.randn((1, self.latent_features), device=self.device)
        z = torch.cat([eps, c], dim=1)
        return self.decoder(z)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        # auraloss expects (input, target), so the reconstruction goes first.
        audio_loss = self.audio_loss_func(x_hat, x)
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor, c: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        c_embedding = self.class_embedder(c).unsqueeze(1)  # (batch, 1, io_features)
        x_embedding = self.data_embedder(x)  # (batch, io_channels, io_features)
        x = torch.cat([x_embedding, c_embedding], dim=1)  # class as an extra channel
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        z = torch.cat([z, c], dim=1)  # condition the decoder on the class
        return self.decoder(z), mean, logvar

    def training_step(
        self, batch: tuple[Tensor, Tensor], batch_idx: int, log: bool = True
    ) -> Tensor:
        x, c = batch
        x_hat, mean, logvar = self(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
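

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal smoke test, assuming DownResConvBlock halves and UpResConvBlock
# doubles the time axis (which the feature-count arithmetic above implies),
# and that this file is run inside its package, since the `.blocks` import is
# relative (e.g. `python -m <package>.<module>`). Shapes and hyperparameters
# here are arbitrary, not values used by the project.
if __name__ == "__main__":
    model = VAE(
        io_channels=1,
        io_features=1024,  # divisible by 2 ** len(channels) == 8
        latent_features=64,
        channels=[16, 32, 64],
        learning_rate=1e-4,
    )
    x = torch.randn(4, 1, 1024)  # (batch, channels, samples)
    x_hat, mean, logvar = model(x)
    # Under the assumptions above: x_hat is (4, 1, 1024), mean/logvar are (4, 64).
    print(x_hat.shape, mean.shape, logvar.shape)
    print(model.sample().shape)  # one waveform decoded from the prior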