from torch import nn, Tensor


def _same_padding(kernel_size: int) -> int:
    """Return the symmetric padding that keeps a stride-1 ``Conv1d`` length-preserving.

    Only odd kernel sizes admit a symmetric padding that preserves the
    sequence length (and that keeps the stride-2 main branch of
    ``DownResConvBlock`` aligned with its 1x1 stride-2 residual). Even sizes
    are rejected here instead of failing later with a shape mismatch when the
    two branches are added in ``forward``.

    Args:
        kernel_size: convolution kernel size.

    Returns:
        ``kernel_size // 2``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )
    return kernel_size // 2


class UpResConvBlock(nn.Module):
    """Residual 1-D convolution block that doubles the sequence length.

    Both branches first upsample by a factor of 2 with ``nn.Upsample``
    (nearest-neighbour, its default mode). The main branch then applies two
    Conv1d -> GroupNorm(1 group) -> GELU stages. The residual branch applies
    a bias-free 1x1 convolution so its channel count matches ``out_channels``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the two main-branch convolutions; must be
            a positive odd integer so both branches have the same length.

    Shape:
        input ``(batch, in_channels, length)`` ->
        output ``(batch, out_channels, 2 * length)``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        padding = _same_padding(kernel_size)
        # Submodule order/indices are unchanged so existing state_dict keys
        # (e.g. "residual.1.weight", "main.1.weight") still load.
        self.residual = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv1d(in_channels, out_channels, 1, 1, bias=False),
        )
        self.main = nn.Sequential(
            nn.Upsample(scale_factor=2),
            # Padding keeps the upsampled length (2 * length) unchanged so the
            # result can be added to the residual branch.
            nn.Conv1d(in_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)


class DownResConvBlock(nn.Module):
    """Residual 1-D convolution block that halves the sequence length.

    The main branch is a stride-2 Conv1d -> GroupNorm(1 group) -> GELU
    followed by a stride-1 Conv1d -> GroupNorm(1 group) -> GELU. The residual
    branch is a bias-free 1x1 convolution with stride 2.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the two main-branch convolutions; must be
            a positive odd integer so both branches have the same length.

    Shape:
        input ``(batch, in_channels, length)`` ->
        output ``(batch, out_channels, (length + 1) // 2)``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        padding = _same_padding(kernel_size)
        # A 1x1 stride-2 conv yields ceil(length / 2) positions; with odd
        # kernel_size and padding kernel_size // 2 the stride-2 main conv
        # yields exactly the same count.
        self.residual = nn.Conv1d(in_channels, out_channels, 1, 2, bias=False)
        self.main = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, 2, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)


class ResConvBlock(nn.Module):
    """Length-preserving residual 1-D convolution block.

    The main branch applies two Conv1d -> GroupNorm(1 group) -> GELU stages.
    The residual branch is the identity when the channel counts match, and a
    bias-free 1x1 convolution otherwise.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the two main-branch convolutions; must be
            a positive odd integer so both branches have the same length.

    Shape:
        input ``(batch, in_channels, length)`` ->
        output ``(batch, out_channels, length)``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        padding = _same_padding(kernel_size)
        self.residual = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv1d(in_channels, out_channels, 1, bias=False)
        )
        self.main = nn.Sequential(
            # Padding keeps the length unchanged so the sum with the residual
            # branch is well-defined.
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)