generation / cvae /models.py
acanivet's picture
v1
bdac835
raw
history blame
No virus
6.02 kB
import torch
from torch import nn, Tensor
from torch.optim import Optimizer
from .blocks import UpResConvBlock, DownResConvBlock
import lightning as L
from auraloss.freq import MultiResolutionSTFTLoss
class Encoder(nn.Module):
    """1-D convolutional encoder producing the parameters of a diagonal Gaussian.

    The input passes through a 1x1 ``Conv1d`` + GELU, then one
    ``DownResConvBlock`` per entry of ``channels``, is flattened, and is
    projected by a ``Linear`` layer to ``2 * out_features`` values which
    ``forward`` splits into ``(mean, logvar)``.

    Args:
        in_channels: number of channels of the input signal.
        in_features: length of the input signal; must be a multiple of
            ``2 ** len(channels)``.
        out_features: size of each of ``mean`` and ``logvar``.
        channels: widths of the successive stages (any non-empty sequence).

    Raises:
        ValueError: if ``channels`` is empty/None, or ``in_features`` is not a
            multiple of ``2 ** len(channels)``.

    NOTE(review): the ``Linear`` input size assumes every ``DownResConvBlock``
    halves the signal length (its definition lives in ``.blocks``) — confirm.
    """

    def __init__(
        self,
        in_channels: int,
        in_features: int,
        out_features: int,
        channels: list = None,
    ) -> None:
        super().__init__()
        # Validate with exceptions rather than `assert`, which is stripped under -O.
        if not channels:
            raise ValueError("channels must be a non-empty sequence of channel widths")
        downscale = 2 ** len(channels)
        if in_features % downscale != 0:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of downscale factor ({downscale})"
            )
        # Copy so tuples work too and the caller's sequence is never touched.
        widths = list(channels)

        modules = [
            nn.Conv1d(in_channels, widths[0], 1),
            nn.GELU(),
        ]
        # Stage i maps widths[i] -> widths[i + 1]; the last stage keeps its width.
        for in_channel, out_channel in zip(widths, widths[1:] + widths[-1:]):
            modules.append(DownResConvBlock(in_channel, out_channel, 1))
        # Integer division: exact because of the divisibility check above.
        n_features = in_features // downscale
        modules += [
            nn.Flatten(),
            nn.Linear(n_features * widths[-1], 2 * out_features),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x`` and return ``(mean, logvar)``.

        Assumes ``x`` is (batch, in_channels, in_features); each returned
        tensor is then (batch, out_features).
        """
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar
class Decoder(nn.Module):
    """1-D convolutional decoder mapping a latent vector back to a signal.

    A ``Linear`` layer expands the input to ``channels[0] * n_features`` values,
    which are reshaped to (batch, channels[0], n_features). One
    ``UpResConvBlock`` is applied per entry of ``channels``, followed by a 1x1
    ``Conv1d`` + GELU. ``forward`` finally applies ``tanh``.

    Args:
        out_channels: number of channels of the produced signal.
        in_features: size of the input (latent) vector.
        out_features: length of the produced signal; must be a multiple of
            ``2 ** len(channels)``.
        channels: widths of the successive stages (any non-empty sequence).

    Raises:
        ValueError: if ``channels`` is empty/None, or ``out_features`` is not a
            multiple of ``2 ** len(channels)``. Previously a non-multiple was
            silently floored, producing an output of the wrong length.

    NOTE(review): assumes every ``UpResConvBlock`` doubles the signal length
    (its definition lives in ``.blocks``) — confirm.
    """

    def __init__(
        self,
        out_channels: int,
        in_features: int,
        out_features: int,
        channels: list = None,
    ) -> None:
        super().__init__()
        if not channels:
            raise ValueError("channels must be a non-empty sequence of channel widths")
        upscale = 2 ** len(channels)
        if out_features % upscale != 0:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of upscale factor ({upscale})"
            )
        widths = list(channels)
        n_features = out_features // upscale

        modules = [
            nn.Linear(in_features, n_features * widths[0]),
            # (batch, widths[0] * n_features) -> (batch, widths[0], n_features)
            nn.Unflatten(-1, (widths[0], n_features)),
        ]
        # Stage i maps widths[i] -> widths[i + 1]; the last stage keeps its width.
        for in_channel, out_channel in zip(widths, widths[1:] + widths[-1:]):
            modules.append(UpResConvBlock(in_channel, out_channel, 1))
        # NOTE(review): this GELU feeds the tanh in forward(). GELU's minimum is
        # about -0.17, so the decoder output is confined to roughly [-0.17, 1)
        # and cannot reproduce strongly negative samples. It is kept as-is
        # because removing it changes the behaviour of trained checkpoints.
        modules += [
            nn.Conv1d(widths[-1], out_channels, 1),
            nn.GELU(),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Decode ``x`` (batch, in_features) and squash the result with ``tanh``."""
        return torch.tanh(self.net(x))
class VAE(L.LightningModule):
    """Unconditional variational auto-encoder trained with a multi-resolution STFT loss.

    Args:
        io_channels: number of channels of the input / reconstructed signal.
        io_features: length of the input / reconstructed signal; must be a
            multiple of ``2 ** len(channels)`` (checked by the encoder).
        latent_features: dimensionality of the latent vector.
        channels: encoder stage widths; the decoder uses them in reverse
            order. The caller's list is not modified.
        learning_rate: AdamW learning rate.
    """

    def __init__(self, io_channels: int, io_features: int, latent_features: int, channels: list, learning_rate: float):
        super().__init__()
        self.encoder = Encoder(io_channels, io_features, latent_features, channels)
        # Build a reversed copy. The previous in-place `channels.reverse()`
        # mutated the caller's list, so reusing a config list for a second
        # model silently flipped its encoder widths.
        decoder_channels = list(reversed(channels))
        self.decoder = Decoder(io_channels, latent_features, io_features, decoder_channels)
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, eps=None):
        """Decode a latent tensor into a signal.

        Args:
            eps: latent tensor of shape (batch, latent_features). If omitted,
                one latent is drawn from the standard-normal prior on this
                module's device.
        """
        if eps is None:
            # The KL term regularises the posterior towards N(0, I), so sample
            # from that prior (randn), not from U[0, 1) (rand).
            eps = torch.randn((1, self.latent_features), device=self.device)
        return self.decoder(eps)

    def loss_function(self, x, x_hat, mean, logvar):
        """Return the STFT reconstruction loss plus the KL divergence to N(0, I).

        Args:
            x: target signal.
            x_hat: reconstruction produced by the decoder.
            mean: posterior mean from the encoder.
            logvar: posterior log-variance from the encoder.
        """
        # auraloss losses take (input, target): the prediction goes first.
        audio_loss = self.audio_loss_func(x_hat, x)
        # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over batch and
        # latent dims. The STFT loss is presumably batch-averaged, so the
        # relative weight of this term likely grows with batch size.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean, logvar):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(x_hat, mean, logvar)``."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        """Compute (and optionally log as ``train_loss``) the loss for ``batch``."""
        x_hat, mean, logvar = self.forward(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """Return an AdamW optimiser over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
class CVAE(L.LightningModule):
    """Class-conditional variational auto-encoder trained with a multi-resolution STFT loss.

    The class label is one-hot encoded. For the encoder it is embedded by a
    ``Linear`` layer into one extra input channel. For the decoder it is
    concatenated to the latent vector.

    Args:
        io_channels: number of channels of the input / reconstructed signal.
        io_features: length of the input / reconstructed signal; must be a
            multiple of ``2 ** len(channels)`` (checked by the encoder).
        latent_features: dimensionality of the latent vector.
        channels: encoder stage widths; the decoder uses them in reverse
            order. The caller's list is not modified.
        num_classes: number of conditioning classes.
        learning_rate: AdamW learning rate.
    """

    def __init__(self, io_channels: int, io_features: int, latent_features: int, channels: list, num_classes: int, learning_rate: float):
        super().__init__()
        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
        # +1 input channel for the embedded class label.
        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
        # Reversed copy: the previous in-place `channels.reverse()` mutated the
        # caller's list.
        decoder_channels = list(reversed(channels))
        self.decoder = Decoder(io_channels, latent_features + num_classes, io_features, decoder_channels)
        self.num_classes = num_classes
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    @torch.no_grad()
    def sample(self, c, eps=None):
        """Decode a latent tensor conditioned on class label(s) ``c``.

        Args:
            c: class label(s) — a Python int, a 0-d tensor, or a 1-D tensor of
                labels (one per sample).
            eps: latent tensor of shape (batch, latent_features). If omitted,
                one latent per label is drawn from the standard-normal prior.
                A single label is broadcast over every row of ``eps``.
        """
        device = self.device if eps is None else eps.device
        labels = torch.as_tensor(c, dtype=torch.long, device=device)
        if labels.dim() == 0:
            labels = labels.unsqueeze(0)
        c_onehot = nn.functional.one_hot(labels, num_classes=self.num_classes).float()
        if eps is None:
            # Sample the N(0, I) prior that the KL term targets, not U[0, 1).
            eps = torch.randn((c_onehot.shape[0], self.latent_features), device=device)
        elif c_onehot.shape[0] == 1 and eps.shape[0] > 1:
            c_onehot = c_onehot.expand(eps.shape[0], -1)
        z = torch.cat([eps, c_onehot], dim=1)
        return self.decoder(z)

    def loss_function(self, x, x_hat, mean, logvar):
        """Return the STFT reconstruction loss plus the KL divergence to N(0, I).

        Args:
            x: target signal.
            x_hat: reconstruction produced by the decoder.
            mean: posterior mean from the encoder.
            logvar: posterior log-variance from the encoder.
        """
        # auraloss losses take (input, target): the prediction goes first.
        audio_loss = self.audio_loss_func(x_hat, x)
        # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over batch and
        # latent dims. The STFT loss is presumably batch-averaged, so the
        # relative weight of this term likely grows with batch size.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean, logvar):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x, c):
        """Encode ``x`` conditioned on labels ``c``, sample, decode.

        Returns ``(x_hat, mean, logvar)``. Assumes ``x`` is
        (batch, io_channels, io_features) and ``c`` is a (batch,) tensor of
        integer labels.
        """
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        # (batch, num_classes) -> (batch, 1, io_features): one extra channel.
        c_embedding = self.class_embedder(c).unsqueeze(1)
        x_embedding = self.data_embedder(x)
        x = torch.cat([x_embedding, c_embedding], dim=1)
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        # The decoder sees the raw one-hot label alongside the latent.
        z = torch.cat([z, c], dim=1)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        """Compute (and optionally log as ``train_loss``) the loss for ``(x, c)``."""
        x, c = batch
        x_hat, mean, logvar = self.forward(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """Return an AdamW optimiser over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer