Spaces:
Sleeping
Sleeping
from torch import nn | |
class UpResConvBlock(nn.Module):
    """Residual 1-D convolution block that doubles the sequence length.

    Both branches first upsample the input by a factor of 2 with
    ``nn.Upsample`` (PyTorch default mode, nearest-neighbour).
    The branches then differ:

    - Residual branch: projects channels with a bias-free 1x1 convolution.
    - Main branch: applies two ``Conv1d -> GroupNorm(1) -> GELU`` stages.

    The two branch outputs are summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the two main-branch convolutions. Must be
            a positive odd integer so that symmetric padding keeps the main
            branch exactly as long as the residual branch.

    Shape:
        Input ``(N, in_channels, L)`` -> output ``(N, out_channels, 2 * L)``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # At stride 1, padding k // 2 preserves the length for odd k.
        # Without it, each conv shrinks the sequence by k - 1, and the sum
        # with the residual branch in forward() fails for any k > 1.
        padding = kernel_size // 2
        # Layer order is unchanged so existing state_dict keys
        # ("residual.1.weight", "main.1.weight", ...) still load.
        self.residual = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv1d(in_channels, out_channels, 1, 1, bias=False),
        )
        self.main = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv1d(in_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)
class DownResConvBlock(nn.Module):
    """Residual 1-D convolution block that halves the sequence length.

    The two branches are:

    - Residual branch: a bias-free 1x1 convolution with stride 2.
    - Main branch: a stride-2 ``Conv1d -> GroupNorm(1) -> GELU`` stage,
      followed by a stride-1 ``Conv1d -> GroupNorm(1) -> GELU`` stage.

    The two branch outputs are summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the two main-branch convolutions. Must be
            a positive odd integer so that symmetric padding keeps the main
            branch exactly as long as the residual branch.

    Shape:
        Input ``(N, in_channels, L)`` -> output
        ``(N, out_channels, ceil(L / 2))``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # With odd k and padding (k - 1) / 2, the stride-2 conv yields
        # floor((L - 1) / 2) + 1 == ceil(L / 2), matching the 1x1 stride-2
        # residual. The stride-1 conv then preserves that length.
        # Without padding, the branches differ in length for any k > 1.
        padding = kernel_size // 2
        # Layer order is unchanged so existing state_dict keys still load.
        self.residual = nn.Conv1d(in_channels, out_channels, 1, 2, bias=False)
        self.main = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, 2, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)
class ResConvBlock(nn.Module):
    """Length-preserving residual 1-D convolution block.

    The two branches are:

    - Residual branch: ``nn.Identity`` when ``in_channels == out_channels``,
      otherwise a bias-free 1x1 convolution that projects the channels.
    - Main branch: two ``Conv1d -> GroupNorm(1) -> GELU`` stages.

    The two branch outputs are summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the two main-branch convolutions. Must be
            a positive odd integer so that symmetric padding keeps the main
            branch exactly as long as the residual branch.

    Shape:
        Input ``(N, in_channels, L)`` -> output ``(N, out_channels, L)``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # Padding k // 2 keeps length L for odd k. Without it, the main
        # branch is 2 * (k - 1) samples shorter than the residual, and the
        # sum in forward() fails.
        padding = kernel_size // 2
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Layer order is unchanged so existing state_dict keys still load.
        self.main = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        """Return ``main(x) + residual(x)``."""
        return self.main(x) + self.residual(x)