import torch
from torch import nn, Tensor
from torch.optim import Optimizer
from .blocks import UpResConvBlock, DownResConvBlock
import lightning as L
from auraloss.freq import MultiResolutionSTFTLoss
from typing import Optional, Sequence


class Encoder(nn.Module):
    """Encode a waveform into the mean and log-variance of a diagonal Gaussian
    posterior over the latent space."""

def __init__(
self,
in_channels: int,
in_features: int,
out_features: int,
channels: Sequence[int],
) -> None:
        super().__init__()
        assert (
            in_features % 2 ** len(channels) == 0
        ), f"in_features ({in_features}) must be a multiple of the downscale factor ({2**len(channels)})"
        modules = [nn.Conv1d(in_channels, channels[0], 1), nn.GELU()]
        # Pair consecutive channel counts; the final block keeps the last width.
        for in_channel, out_channel in zip(channels, list(channels[1:]) + [channels[-1]]):
            modules += [
                DownResConvBlock(in_channel, out_channel, 1),
            ]
        # Each DownResConvBlock halves the time axis.
        n_features = in_features // 2 ** len(channels)
        modules += [
            nn.Flatten(),
            # Project to 2 * out_features: one half for the mean, one for the log-variance.
            nn.Linear(n_features * channels[-1], 2 * out_features),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar
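
# Example (illustrative sketch, not part of the module): assuming each
# DownResConvBlock halves the time axis, an Encoder built with
# in_features=1024 and channels=[32, 64, 128] downsamples a (batch, 1, 1024)
# waveform to 128 time steps, flattens it to (batch, 128 * 128), and returns
# two (batch, out_features) tensors:
#
#     enc = Encoder(in_channels=1, in_features=1024, out_features=16,
#                   channels=[32, 64, 128])
#     mean, logvar = enc(torch.randn(4, 1, 1024))  # both (4, 16)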


class Decoder(nn.Module):
    """Decode a latent vector back into a waveform in [-1, 1]."""

def __init__(
self,
out_channels: int,
in_features: int,
out_features: int,
channels: Sequence[int],
) -> None:
        super().__init__()
        # Each UpResConvBlock doubles the time axis, mirroring the encoder.
        n_features = out_features // 2 ** len(channels)
        modules = [
            nn.Linear(in_features, n_features * channels[0]),
            nn.Unflatten(-1, (channels[0], n_features)),
        ]
        for in_channel, out_channel in zip(channels, list(channels[1:]) + [channels[-1]]):
            modules += [
                UpResConvBlock(in_channel, out_channel, 1),
            ]
        modules += [nn.Conv1d(channels[-1], out_channels, 1), nn.GELU()]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        # tanh bounds the reconstructed waveform to [-1, 1].
        return torch.tanh(self.net(x))


class VAE(L.LightningModule):
    """Unconditional VAE trained with a multi-resolution STFT reconstruction
    loss plus a KL divergence term."""

def __init__(
self,
io_channels: int,
io_features: int,
latent_features: int,
channels: Sequence[int],
learning_rate: float,
) -> None:
        super().__init__()
        self.encoder = Encoder(io_channels, io_features, latent_features, channels)
        # The decoder mirrors the encoder, so it takes the channel list reversed
        # (reversed copy rather than in-place reverse, to leave the caller's list intact).
        self.decoder = Decoder(
            io_channels, latent_features, io_features, list(reversed(channels))
        )
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, eps: Optional[Tensor] = None) -> Tensor:
        if eps is None:
            # Draw from the standard-normal prior over the latent space.
            eps = torch.randn((1, self.latent_features), device=self.device)
        return self.decoder(eps)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        # auraloss expects (input, target), i.e. (reconstruction, reference).
        audio_loss = self.audio_loss_func(x_hat, x)
        # KL divergence between N(mean, exp(logvar)) and the standard-normal
        # prior, summed over the batch and latent dimensions.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I),
        # where std = exp(0.5 * logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        x_hat, mean, logvar = self.forward(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
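
# Example (illustrative sketch): training the VAE with a Lightning Trainer.
# `dataloader` is a hypothetical DataLoader yielding (batch, io_channels,
# io_features) waveform tensors; the hyperparameters are arbitrary:
#
#     model = VAE(io_channels=1, io_features=1024, latent_features=16,
#                 channels=[32, 64, 128], learning_rate=1e-4)
#     trainer = L.Trainer(max_epochs=10)
#     trainer.fit(model, dataloader)
#     audio = model.sample()  # (1, 1, 1024)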


class CVAE(L.LightningModule):
    """Conditional VAE: the encoder sees the waveform stacked with an embedded
    class label, and the decoder receives the latent code concatenated with the
    one-hot label."""

def __init__(
self,
io_channels: int,
io_features: int,
latent_features: int,
channels: Sequence[int],
num_classes: int,
learning_rate: float,
    ) -> None:
        super().__init__()
        # Embed the one-hot class label into a vector the length of the
        # waveform, so it can be stacked onto the input as an extra channel.
        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
        # One extra input channel carries the class embedding.
        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
        # The decoder mirrors the encoder, so it takes the channel list reversed;
        # its input is the latent code concatenated with the one-hot label.
        self.decoder = Decoder(
            io_channels, latent_features + num_classes, io_features,
            list(reversed(channels)),
        )
self.num_classes = num_classes
self.latent_features = latent_features
self.audio_loss_func = MultiResolutionSTFTLoss()
self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, c: Tensor, eps: Optional[Tensor] = None) -> Tensor:
        # One-hot encode the class index and append it to the latent code.
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float().unsqueeze(0)
        if eps is None:
            # Draw from the standard-normal prior over the latent space.
            eps = torch.randn((1, self.latent_features), device=self.device)
        z = torch.cat([eps, c], dim=1)
        return self.decoder(z)
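
# Example (illustrative sketch): generating one waveform for class 3 from a
# trained model. `cvae` is a hypothetical trained instance, and `c` is a
# scalar LongTensor holding the class index:
#
#     audio = cvae.sample(torch.tensor(3))  # (1, io_channels, io_features)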

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        # auraloss expects (input, target), i.e. (reconstruction, reference).
        audio_loss = self.audio_loss_func(x_hat, x)
        # KL divergence to the standard-normal prior, summed over batch and latents.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor, c: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        # Stack the class embedding onto the waveform as an extra channel.
        c_embedding = self.class_embedder(c).unsqueeze(1)
        x_embedding = self.data_embedder(x)
        x = torch.cat([x_embedding, c_embedding], dim=1)
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        # Condition the decoder on the one-hot label as well.
        z = torch.cat([z, c], dim=1)
        return self.decoder(z), mean, logvar

    def training_step(
        self, batch: tuple[Tensor, Tensor], batch_idx: int, log: bool = True
    ) -> Tensor:
        x, c = batch
        x_hat, mean, logvar = self.forward(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
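

if __name__ == "__main__":
    # Smoke test (illustrative sketch): run a dummy batch through the CVAE and
    # check shapes. The hyperparameters are arbitrary, and running this assumes
    # the package's `blocks` module (UpResConvBlock / DownResConvBlock) is
    # importable and that each block halves or doubles the time axis.
    model = CVAE(
        io_channels=1,
        io_features=1024,
        latent_features=16,
        channels=[32, 64, 128],
        num_classes=4,
        learning_rate=1e-4,
    )
    x = torch.randn(2, 1, 1024)    # (batch, io_channels, io_features)
    c = torch.randint(0, 4, (2,))  # class indices
    x_hat, mean, logvar = model(x, c)
    print(x_hat.shape, mean.shape, logvar.shape)  # (2, 1, 1024) (2, 16) (2, 16)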