import torch
import torch.nn as nn
import torch.nn.functional as F


class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D upsample layer used in CogVideoX by Tsinghua University & ZhipuAI.

    The layer first doubles the resolution of a 5D input with `F.interpolate(..., scale_factor=2.0)` and then runs a
    2D convolution over every frame independently (frames are folded into the batch axis for the convolution).

    # TODO: Wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension. When `True`, the interpolation also acts on the frame
            axis (see `forward`); when `False`, only height and width are upsampled.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.compress_time = compress_time

    @staticmethod
    def _interpolate_with_time(video: torch.Tensor) -> torch.Tensor:
        # Upsample a 5D tensor whose frame axis (dim 2) takes part in the interpolation.
        num_frames = video.shape[2]

        if num_frames > 1 and num_frames % 2 == 1:
            # Odd frame count: the first frame is treated as a standalone image (4D, spatial
            # interpolation only) while the remaining frames are interpolated as a 5D volume.
            head = F.interpolate(video[:, :, 0], scale_factor=2.0)
            tail = F.interpolate(video[:, :, 1:], scale_factor=2.0)
            return torch.cat([head.unsqueeze(2), tail], dim=2)

        if num_frames > 1:
            # Even frame count: interpolate the whole 5D volume at once.
            return F.interpolate(video, scale_factor=2.0)

        # Single frame: drop the frame axis, upsample as an image, then restore the axis.
        image = F.interpolate(video.squeeze(2), scale_factor=2.0)
        return image.unsqueeze(2)

    @staticmethod
    def _interpolate_per_frame(video: torch.Tensor) -> torch.Tensor:
        # Upsample height and width only by folding the frame axis into the batch axis.
        batch, channels, frames, height, width = video.shape
        stacked = video.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        stacked = F.interpolate(stacked, scale_factor=2.0)
        restored = stacked.reshape(batch, frames, channels, *stacked.shape[2:])
        return restored.permute(0, 2, 1, 3, 4)

    def _convolve_per_frame(self, video: torch.Tensor) -> torch.Tensor:
        # Apply the 2D convolution to every frame, then restore the 5D layout.
        batch, channels, frames, height, width = video.shape
        stacked = video.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        stacked = self.conv(stacked)
        restored = stacked.reshape(batch, frames, *stacked.shape[1:])
        return restored.permute(0, 2, 1, 3, 4)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        r"""
        Upsample `inputs` and apply the per-frame convolution.

        Args:
            inputs (`torch.Tensor`):
                5D tensor; the code unpacks its shape as `(batch, channels, frames, height, width)`.

        Returns:
            `torch.Tensor`: 5D tensor with the channel count produced by `self.conv`. With `compress_time=True` the
            frame axis is interpolated as well: an odd frame count greater than one keeps the first frame as a
            single frame and interpolates the rest, an even frame count is interpolated as a whole, and a single
            frame is only upsampled spatially. With `compress_time=False` the frame count is unchanged.
        """
        if self.compress_time:
            upsampled = self._interpolate_with_time(inputs)
        else:
            # only interpolate 2D
            upsampled = self._interpolate_per_frame(inputs)

        return self._convolve_per_frame(upsampled)