import numpy as np import torch import torch.nn as nn import torchvision import torchvision.models as models from typing import Union, List, Tuple import os import video3d.utils.misc as misc import torch.nn.functional as F from siren_pytorch import SirenNet from video3d.triplane_texture.lift_architecture import Lift_Encoder from video3d.triplane_texture.triplane_transformer import Triplane_Transformer EPS = 1e-7 def get_activation(name, inplace=True, lrelu_param=0.2): if name == 'tanh': return nn.Tanh() elif name == 'sigmoid': return nn.Sigmoid() elif name == 'relu': return nn.ReLU(inplace=inplace) elif name == 'lrelu': return nn.LeakyReLU(lrelu_param, inplace=inplace) else: raise NotImplementedError class MLPWithPositionalEncoding(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, symmetrize=False): super().__init__() self.extra_dim = extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) self.symmetrize = symmetrize def forward(self, x, feat=None): assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim if self.symmetrize: xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.relu(self.in_layer(x_in)) if feat is not None: # if len(feat.shape) == 1: # for _ in range(len(x_in.shape) - 1): # feat = feat.unsqueeze(0) # feat = feat.repeat(*x_in.shape[:-1], 1) x_in = torch.concat([x_in, feat], dim=-1) return self.mlp(x_in) class MLPWithPositionalEncoding_Style(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, symmetrize=False, style_choice='film'): super().__init__() self.extra_dim = extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) if extra_dim == 0: self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) else: if style_choice == 'film': self.mlp = MLP_FiLM(nf, cout, num_layers, nf, dropout, activation) self.style_mlp = MLP(extra_dim, nf*2, 2, nf, dropout, None) elif style_choice == 'mod': self.mlp = MLP_Mod(nf, cout, num_layers, nf, dropout, activation) self.style_mlp = MLP(extra_dim, nf, 2, nf, dropout, None) else: raise NotImplementedError self.style_choice = style_choice self.symmetrize = symmetrize def forward(self, x, feat=None): assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim if self.symmetrize: xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.relu(self.in_layer(x_in)) if feat is not None: style = self.style_mlp(feat) if 
self.style_choice == 'film': style = style.reshape(style.shape[:-1] + (-1, 2)) out = self.mlp(x_in, style) else: out = self.mlp(x_in) return out class MLP_FiLM(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): # default no dropout super().__init__() assert num_layers >= 1 self.num_layers = num_layers if num_layers == 1: self.network = Linear_FiLM(cin, cout, bias=False) else: self.relu = nn.ReLU(inplace=True) for i in range(num_layers): if i == 0: setattr(self, f'linear_{i}', Linear_FiLM(cin, nf, bias=False)) elif i == (num_layers-1): setattr(self, f'linear_{i}', Linear_FiLM(nf, cout, bias=False)) else: setattr(self, f'linear_{i}', Linear_FiLM(nf, nf, bias=False)) def forward(self, input, style): if self.num_layers == 1: out = self.network(input, style) else: x = input for i in range(self.num_layers): linear_layer = getattr(self, f'linear_{i}') if i == (self.num_layers - 1): x = linear_layer(x, style) else: x = linear_layer(x, style) x = self.relu(x) out = x return out class MLP_Mod(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): # default no dropout super().__init__() assert num_layers >= 1 self.num_layers = num_layers if num_layers == 1: self.network = Linear_Mod(cin, cout, bias=False) else: self.relu = nn.ReLU(inplace=True) for i in range(num_layers): if i == 0: setattr(self, f'linear_{i}', Linear_Mod(cin, nf, bias=False)) elif i == (num_layers-1): setattr(self, f'linear_{i}', Linear_Mod(nf, cout, bias=False)) else: setattr(self, f'linear_{i}', Linear_Mod(nf, nf, bias=False)) def forward(self, input, style): if self.num_layers == 1: out = self.network(input, style) else: x = input for i in range(self.num_layers): linear_layer = getattr(self, f'linear_{i}') if i == (self.num_layers - 1): x = linear_layer(x, style) else: x = linear_layer(x, style) x = self.relu(x) out = x return out import math class Linear_FiLM(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) if bias: self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self) -> None: nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) def forward(self, input, style): # if input is [..., D], style should be [..., D, 2] x = input * style[..., 0] + style[..., 1] return torch.nn.functional.linear(x, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None ) class Linear_Mod(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) if bias: self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: self.register_parameter('bias', None) self.reset_parameters() def 
reset_parameters(self) -> None: nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) def forward(self, input, style): # weight: [out_features, in_features] # style: [..., in_features] if len(style.shape) > 1: style = style.reshape(-1, style.shape[-1]) style = style[0] weight = self.weight * style.unsqueeze(0) decoefs = ((weight * weight).sum(dim=-1, keepdim=True) + 1e-5).sqrt() weight = weight / decoefs return torch.nn.functional.linear(input, weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None ) class MLPTextureSimple(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, min_max=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, perturb_normal=False, symmetrize=False, texture_act='relu', linear_bias=False): super().__init__() self.extra_dim = extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) if texture_act == 'sin': print('using siren network for texture mlp here') self.mlp = SirenNet( dim_in=(nf + extra_dim), dim_hidden=nf, dim_out=cout, num_layers=num_layers, final_activation=get_activation(activation), w0_initial=30, use_bias=linear_bias, dropout=dropout ) else: self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) self.perturb_normal = perturb_normal self.symmetrize = symmetrize if min_max is not None: self.register_buffer('min_max', min_max) else: self.min_max = None self.bsdf = None def sample(self, x, feat=None): assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) b, h, w, c = x.shape if self.symmetrize: xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x x = x.view(-1, c) if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.in_layer(x_in) if feat is not None: feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) x_in = torch.concat([x_in, feat], dim=-1) out = self.mlp(self.relu(x_in)) if self.min_max is not None: out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] return out.view(b, h, w, -1) class MLPTextureTriplane(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, min_max=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, perturb_normal=False, symmetrize=False, texture_act='relu', linear_bias=False, cam_pos_z_offset=10., grid_scale=7,): super().__init__() self.extra_dim = extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) self.feat_net = Triplane_Transformer( emb_dim=256, num_layers=8, triplane_dim=80, triplane_scale=grid_scale ) self.extra_dim -= 
extra_dim self.extra_dim += (self.feat_net.triplane_dim * 3) if texture_act == 'sin': print('using siren network for texture mlp here') self.mlp = SirenNet( dim_in=(nf + self.extra_dim), dim_hidden=nf, dim_out=cout, num_layers=num_layers, final_activation=get_activation(activation), w0_initial=30, use_bias=linear_bias, dropout=dropout ) else: self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) self.perturb_normal = perturb_normal self.symmetrize = symmetrize if min_max is not None: self.register_buffer('min_max', min_max) else: self.min_max = None self.bsdf = None def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) b, h, w, c = x.shape if self.symmetrize: xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x if isinstance(feat_map, dict): feat_map = feat_map["im_features_map"] feat_map = feat_map.permute(0, 2, 3, 1) _, ph, pw, _ = feat_map.shape feat_map = feat_map.reshape(feat_map.shape[0], ph*pw, feat_map.shape[-1]) pts_feat = self.feat_net(feat_map, x.reshape(b, -1, 3)) pts_c = pts_feat.shape[-1] pts_feat = pts_feat.reshape(-1, pts_c) x = x.view(-1, c) if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.in_layer(x_in) x_in = torch.concat([x_in, pts_feat], dim=-1) out = self.mlp(self.relu(x_in)) if self.min_max is not None: out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] return out.view(b, h, w, -1) class LocalFeatureBlock(nn.Module): def __init__(self, local_feat_dim, input_dim=384, output_dim=384, upscale_num=3): super().__init__() self.local_feat_dim = local_feat_dim self.conv_list = nn.ModuleList([]) self.upscale_list = nn.ModuleList([]) for i in range(upscale_num): if i == 0: self.conv_list.append(nn.Conv2d(input_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) else: self.conv_list.append(nn.Conv2d(local_feat_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) self.upscale_list.append(nn.PixelShuffle(2)) self.conv_head = nn.Conv2d(local_feat_dim, output_dim, 3, stride=1, padding=1, dilation=1) def forward(self, x): for idx, conv in enumerate(self.conv_list): x = conv(x) x = self.upscale_list[idx](x) out = self.conv_head(x) return out class MLPTextureLocal(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, min_max=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, perturb_normal=False, symmetrize=False, texture_way=None, larger_tex_dim=False, cam_pos_z_offset=10., grid_scale=7.): super().__init__() self.extra_dim = extra_dim self.cam_pos_z_offset = cam_pos_z_offset self.grid_scale = grid_scale local_feat_dim = 64 assert texture_way is not None self.texture_way = texture_way if 'local' in texture_way and 'global' in texture_way: # self.extra_dim = extra_dim + local_feat_dim self.extra_dim = extra_dim elif 'local' in texture_way and 'global' not in texture_way: # self.extra_dim = local_feat_dim self.extra_dim = extra_dim elif 'local' not in texture_way and 'global' in texture_way: self.extra_dim = extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin 
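        # Note (descriptive comment, added for clarity): the local image features (the 384-dim
        # feature map passed in at sample time) are mapped to nf by a bias-free linear layer
        # below, so points that receive an all-zero (occluded) feature stay zero after the
        # projection; also, every texture_way branch above leaves self.extra_dim == extra_dim,
        # i.e. the MLP input width only accounts for the global feature.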
# self.local_feature_block = LocalFeatureBlock(local_feat_dim=local_feat_dim, input_dim=384, output_dim=256) self.local_feature_block = nn.Linear(384, nf, bias=False) self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) self.perturb_normal = perturb_normal self.symmetrize = symmetrize if min_max is not None: self.register_buffer('min_max', min_max) else: self.min_max = None self.bsdf = None def get_uv_depth(self, xyz, mvp): # xyz: [b, k, 3] # mvp: [b, 4, 4] cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) cam3 = cam4[..., :3] / cam4[..., 3:4] cam_uv = cam3[..., :2] # cam_uv = cam_uv.detach() cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3) cam_depth = cam_depth / self.grid_scale * 2 cam_depth = cam_depth[..., 2:3] # cam_depth = cam_depth.detach() return cam_uv, cam_depth def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): # here the xyz is deformed points # and we don't cast any symmtery here b, k, c = xyz.shape THRESHOLD = 1e-4 if isinstance(feat_map, torch.Tensor): coordinates = xyz # use pre-symmetry points to get feature and record depth cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) cam_uv = cam_uv.detach() cam_depth = cam_depth.detach() # get local feature feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] self.input_pts = coordinates.detach() elif isinstance(feat_map, dict): original_mvp = feat_map['original_mvp'] local_feat_map = feat_map['im_features_map'] original_depth = self.input_depth[0:b] coordinates = xyz cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) cam_uv = cam_uv.detach() cam_depth = cam_depth.detach() project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] use_mask = cam_depth <= project_depth + THRESHOLD feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value return ret_feature def proj_sample(self, xyz, feat_map, mvp, w2c, img_h, img_w, xyz_before_sym=None): # the new one with no input feature map upsampling # feat_map: [B, C, H, W] b, k, c = xyz.shape if isinstance(feat_map, torch.Tensor): if xyz_before_sym is None: coordinates = xyz else: coordinates = xyz_before_sym # use pre-symmetry points to get feature and record depth cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) cam_uv = cam_uv.detach() cam_depth = cam_depth.detach() # get local feature feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] self.input_pts = coordinates.detach() elif isinstance(feat_map, dict): original_mvp = feat_map['original_mvp'] local_feat_map = feat_map['im_features_map'] THRESHOLD = 1e-4 original_depth = self.input_depth[0:b] # if b == 1: # from pdb import set_trace; set_trace() # tmp_mask = xyz[0].reshape(256, 256, 3).sum(dim=-1) != 0 # tmp_mask = 
tmp_mask.cpu().numpy() # tmp_mask = tmp_mask * 255 # src_dp = self.input_depth[0,:,:,0].cpu().numpy() # input_pts = self.input_pts[0].cpu().numpy() # input_mask = self.input_pts[0].reshape(256, 256, 3).sum(dim=-1) != 0 # input_mask = input_mask.int().cpu().numpy() # input_mask = input_mask * 255 # np.save('./tmp_save/src_dp.npy', src_dp) # np.save('./tmp_save/input_pts.npy', input_pts) # import cv2 # cv2.imwrite('./tmp_save/input_mask.png', input_mask) # cv2.imwrite('./tmp_save/mask.png', tmp_mask) # test_pts_pos = xyz[0].cpu().numpy() # np.save('./tmp_save/test_pts_pos.npy', test_pts_pos) # test_pts_raw = xyz_before_sym[0].cpu().numpy() # np.save('./tmp_save/test_pts_raw.npy', test_pts_raw) # mvp_now = mvp[0].detach().cpu().numpy() # mvp_original = original_mvp[0].detach().cpu().numpy() # np.save('./tmp_save/mvp_now.npy', mvp_now) # np.save('./tmp_save/mvp_original.npy', mvp_original) if xyz_before_sym is None: # just check the project depth of xyz coordinates = xyz cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) cam_uv = cam_uv.detach() cam_depth = cam_depth.detach() project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] use_mask = cam_depth <= project_depth + THRESHOLD feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) else: # need to double check, but now we are still use symmetry! Even if the two points are all visible in input view coords_inp = xyz x_check, y_check, z_check = xyz.unbind(-1) xyz_check = torch.stack([-1 * x_check, y_check, z_check], -1) coords_rev = xyz_check # we directly use neg-x to get the points of another side uv_inp, dp_inp = self.get_uv_depth(coords_inp, original_mvp) uv_rev, dp_rev = self.get_uv_depth(coords_rev, original_mvp) uv_inp = uv_inp.detach() uv_rev = uv_rev.detach() dp_inp = dp_inp.detach() dp_rev = dp_rev.detach() proj_feat_inp = F.grid_sample(local_feat_map, uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] proj_feat_rev = F.grid_sample(local_feat_map, uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] proj_dp_inp = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] proj_dp_rev = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] use_mask_inp = dp_inp <= proj_dp_inp + THRESHOLD use_mask_rev = dp_rev <= proj_dp_rev + THRESHOLD # for those points we can see in two sides, we use average use_mask_inp = use_mask_inp.int() use_mask_rev = use_mask_rev.int() both_vis = (use_mask_inp == 1) & (use_mask_rev == 1) use_mask_inp[both_vis] = 0.5 use_mask_rev[both_vis] = 0.5 feature = proj_feat_inp * use_mask_inp.repeat(1, 1, proj_feat_inp.shape[-1]) + proj_feat_rev * use_mask_rev.repeat(1, 1, proj_feat_rev.shape[-1]) else: raise NotImplementedError ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value return ret_feature def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) b, h, w, c = x.shape xyz_before_sym = None if self.symmetrize: xyz_before_sym = 
x.reshape(b, -1, c) xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x mvp = mvp.detach() # [b, 4, 4] w2c = w2c.detach() # [b, 4, 4] pts_xyz = x.reshape(b, -1, c) deform_xyz = deform_xyz.reshape(b, -1, c) if 'global' in self.texture_way and 'local' in self.texture_way: global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) # feature_rep = torch.concat([global_feat, local_feat], dim=-1) feature_rep = global_feat + local_feat elif 'global' not in self.texture_way and 'local' in self.texture_way: # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) feature_rep = local_feat elif 'global' in self.texture_way and 'local' not in self.texture_way: global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) feature_rep = global_feat else: global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) feature_rep = global_feat x = x.view(-1, c) if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.in_layer(x_in) # if feat is not None: # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) # x_in = torch.concat([x_in, feat], dim=-1) x_in = torch.concat([x_in, feature_rep], dim=-1) out = self.mlp(self.relu(x_in)) if self.min_max is not None: out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] return out.view(b, h, w, -1) class LiftTexture(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, min_max=None, n_harmonic_functions=10, omega0=1, extra_dim=0, embed_concat_pts=True, perturb_normal=False, symmetrize=False, texture_way=None, cam_pos_z_offset=10., grid_scale=7., local_feat_dim=128, grid_size=32, optim_latent=False): super().__init__() self.extra_dim = extra_dim self.cam_pos_z_offset = cam_pos_z_offset self.grid_scale = grid_scale assert texture_way is not None self.extra_dim = local_feat_dim + extra_dim if n_harmonic_functions > 0: self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) dim_in = cin * 2 * n_harmonic_functions self.embed_concat_pts = embed_concat_pts if embed_concat_pts: dim_in += cin else: self.embedder = None dim_in = cin self.encoder = Lift_Encoder( cin=384, feat_dim=local_feat_dim, grid_scale=grid_scale / 2, # the dmtet is initialized in (-0.5, 0.5) grid_size=grid_size, optim_latent=optim_latent, with_z_feature=True, cam_pos_z_offset=cam_pos_z_offset ) self.in_layer = nn.Linear(dim_in, nf) self.relu = nn.ReLU(inplace=True) self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) self.perturb_normal = perturb_normal self.symmetrize = symmetrize if min_max is not None: self.register_buffer('min_max', min_max) else: self.min_max = None self.bsdf = None def get_uv_depth(self, xyz, mvp): # xyz: [b, k, 3] # mvp: [b, 4, 4] cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) cam3 = cam4[..., :3] / cam4[..., 3:4] cam_uv = cam3[..., :2] # cam_uv = cam_uv.detach() cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3) cam_depth = cam_depth / self.grid_scale * 2 cam_depth = cam_depth[..., 2:3] # cam_depth = cam_depth.detach() 
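        # Shape note (illustrative): with xyz [b, k, 3] and mvp [b, 4, 4],
        #   cam4 = pad(xyz, 1) @ mvp^T            -> [b, k, 4] homogeneous clip coordinates
        #   cam3 = cam4[..., :3] / cam4[..., 3:4] -> coordinates after the perspective divide
        #   cam_uv    = cam3[..., :2]             -> [b, k, 2], directly usable by F.grid_sample
        #   cam_depth                             -> [b, k, 1], z shifted by cam_pos_z_offset and
        #                                            rescaled by 2 / grid_scale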
return cam_uv, cam_depth def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): # here the xyz is deformed points # and we don't cast any symmtery here if isinstance(feat_map, torch.Tensor): feature = self.encoder(feat_map, mvp, xyz, inference="unproject") elif isinstance(feat_map, dict): feature = self.encoder(feat_map['im_features_map'], mvp, xyz, inference="sample") C = feature.shape[-1] feature = feature.reshape(-1, C) return feature def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) b, h, w, c = x.shape xyz_before_sym = None if self.symmetrize: xyz_before_sym = x.reshape(b, -1, c) xs, ys, zs = x.unbind(-1) x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x mvp = mvp.detach() # [b, 4, 4] w2c = w2c.detach() # [b, 4, 4] pts_xyz = x.reshape(b, -1, c) deform_xyz = deform_xyz.reshape(b, -1, c) global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) feature_rep = torch.concat([global_feat, local_feat], dim=-1) x = x.view(-1, c) if self.embedder is not None: x_in = self.embedder(x) if self.embed_concat_pts: x_in = torch.cat([x, x_in], -1) else: x_in = x x_in = self.in_layer(x_in) # if feat is not None: # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) # x_in = torch.concat([x_in, feat], dim=-1) x_in = torch.concat([x_in, feature_rep], dim=-1) out = self.mlp(self.relu(x_in)) if self.min_max is not None: out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] return out.view(b, h, w, -1) class HarmonicEmbedding(nn.Module): def __init__(self, n_harmonic_functions=10, omega0=1): """ Positional Embedding implementation (adapted from Pytorch3D). Given an input tensor `x` of shape [minibatch, ... , dim], the harmonic embedding layer converts each feature in `x` into a series of harmonic features `embedding` as follows: embedding[..., i*dim:(i+1)*dim] = [ sin(x[..., i]), sin(2*x[..., i]), sin(4*x[..., i]), ... sin(2**self.n_harmonic_functions * x[..., i]), cos(x[..., i]), cos(2*x[..., i]), cos(4*x[..., i]), ... cos(2**self.n_harmonic_functions * x[..., i]) ] Note that `x` is also premultiplied by `omega0` before evaluting the harmonic functions. 
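        For example, with the default n_harmonic_functions=10 and a 3-D input
        (x, y, z), the embedding has 3 * 10 * 2 = 60 channels; in general the
        output dimension is `dim * n_harmonic_functions * 2`.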
""" super().__init__() self.frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions)) def forward(self, x): """ Args: x: tensor of shape [..., dim] Returns: embedding: a harmonic embedding of `x` of shape [..., n_harmonic_functions * dim * 2] """ embed = (x[..., None] * self.frequencies.to(x.device)).view(*x.shape[:-1], -1) return torch.cat((embed.sin(), embed.cos()), dim=-1) class VGGEncoder(nn.Module): def __init__(self, cout, pretrained=False): super().__init__() if pretrained: raise NotImplementedError vgg = models.vgg16() self.vgg_encoder = nn.Sequential(vgg.features, vgg.avgpool) self.linear1 = nn.Linear(25088, 4096) self.linear2 = nn.Linear(4096, cout) self.relu = nn.ReLU(inplace=True) def forward(self, x): batch_size, _, _, _ = x.shape out = self.relu(self.linear1(self.vgg_encoder(x).view(batch_size, -1))) return self.linear2(out) class ResnetEncoder(nn.Module): def __init__(self, cout, pretrained=False): super().__init__() self.resnet = nn.Sequential(list(models.resnet18(weights="DEFAULT" if pretrained else None).modules())[:-1]) self.final_linear = nn.Linear(512, cout) def forward(self, x): return self.final_linear(self.resnet(x)) class Encoder(nn.Module): def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): super().__init__() network = [ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 nn.GroupNorm(16, nf), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 nn.GroupNorm(16*2, nf*2), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 nn.GroupNorm(16*4, nf*4), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 # nn.GroupNorm(16*8, nf*8), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] add_downsample = int(np.log2(in_size//128)) if add_downsample > 0: for _ in range(add_downsample): network += [ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 # nn.GroupNorm(16*8, nf*8), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] network += [ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 nn.LeakyReLU(0.2, inplace=True), ] if zdim is None: network += [ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 ] else: network += [ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), ] if activation is not None: network += [get_activation(activation)] self.network = nn.Sequential(*network) def forward(self, input): return self.network(input).reshape(input.size(0), -1) class EncoderWithDINO(nn.Module): def __init__(self, cin_rgb, cin_dino, cout, in_size=128, zdim=None, nf=64, activation=None): super().__init__() network_rgb_in = [ nn.Conv2d(cin_rgb, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 nn.GroupNorm(16, nf), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 nn.GroupNorm(16*2, nf*2), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 nn.GroupNorm(16*4, 
nf*4), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] self.network_rgb_in = nn.Sequential(*network_rgb_in) network_dino_in = [ nn.Conv2d(cin_dino, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 nn.GroupNorm(16, nf), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 nn.GroupNorm(16*2, nf*2), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 nn.GroupNorm(16*4, nf*4), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] self.network_dino_in = nn.Sequential(*network_dino_in) network_fusion = [ nn.Conv2d(nf*4*2, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 # nn.GroupNorm(16*8, nf*8), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] add_downsample = int(np.log2(in_size//128)) if add_downsample > 0: for _ in range(add_downsample): network_fusion += [ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 # nn.GroupNorm(16*8, nf*8), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), ] network_fusion += [ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 nn.LeakyReLU(0.2, inplace=True), ] if zdim is None: network_fusion += [ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 ] else: network_fusion += [ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), ] if activation is not None: network_fusion += [get_activation(activation)] self.network_fusion = nn.Sequential(*network_fusion) def forward(self, rgb_image, dino_image): rgb_feat = self.network_rgb_in(rgb_image) dino_feat = self.network_dino_in(dino_image) out = self.network_fusion(torch.cat([rgb_feat, dino_feat], dim=1)) return out.reshape(rgb_image.size(0), -1) class Encoder32(nn.Module): def __init__(self, cin, cout, nf=256, activation=None): super().__init__() network = [ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 nn.GroupNorm(nf//4, nf), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 nn.GroupNorm(nf//4, nf), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 nn.GroupNorm(nf//4, nf), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 ] if activation is not None: network += [get_activation(activation)] self.network = nn.Sequential(*network) def forward(self, input): return self.network(input).reshape(input.size(0), -1) class MLP(nn.Module): def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, inner_act='relu', linear_bias=False): super().__init__() assert num_layers >= 1 layer_act = get_activation(inner_act) if num_layers == 1: network = [nn.Linear(cin, cout, bias=linear_bias)] else: # network = [nn.Linear(cin, nf, bias=False)] # for _ in range(num_layers-2): # network += [ # nn.ReLU(inplace=True), # nn.Linear(nf, nf, bias=False)] # if dropout: # network += [nn.Dropout(dropout)] # network += [ # nn.ReLU(inplace=True), # nn.Linear(nf, cout, bias=False)] network = [nn.Linear(cin, nf, bias=linear_bias)] for _ in range(num_layers-2): network += [ layer_act, 
nn.Linear(nf, nf, bias=linear_bias)] if dropout: network += [nn.Dropout(dropout)] network += [ layer_act, nn.Linear(nf, cout, bias=linear_bias)] if activation is not None: network += [get_activation(activation)] self.network = nn.Sequential(*network) def forward(self, input): return self.network(input) class Embedding(nn.Module): def __init__(self, cin, cout, zdim=128, nf=64, activation=None): super().__init__() network = [ nn.Linear(cin, nf, bias=False), nn.ReLU(inplace=True), nn.Linear(nf, zdim, bias=False), nn.ReLU(inplace=True), nn.Linear(zdim, cout, bias=False)] if activation is not None: network += [get_activation(activation)] self.network = nn.Sequential(*network) def forward(self, input): return self.network(input.reshape(input.size(0), -1)).reshape(input.size(0), -1) class PerceptualLoss(nn.Module): def __init__(self, requires_grad=False): super(PerceptualLoss, self).__init__() mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) self.register_buffer('mean_rgb', mean_rgb) self.register_buffer('std_rgb', std_rgb) vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features self.slice1 = nn.Sequential() self.slice2 = nn.Sequential() self.slice3 = nn.Sequential() self.slice4 = nn.Sequential() for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def normalize(self, x): out = x/2 + 0.5 out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1) return out def __call__(self, im1, im2, mask=None, conf_sigma=None): im = torch.cat([im1,im2], 0) im = self.normalize(im) # normalize input ## compute features feats = [] f = self.slice1(im) feats += [torch.chunk(f, 2, dim=0)] f = self.slice2(f) feats += [torch.chunk(f, 2, dim=0)] f = self.slice3(f) feats += [torch.chunk(f, 2, dim=0)] f = self.slice4(f) feats += [torch.chunk(f, 2, dim=0)] losses = [] for f1, f2 in feats[2:3]: # use relu3_3 features only loss = (f1-f2)**2 if conf_sigma is not None: loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log() if mask is not None: b, c, h, w = loss.shape _, _, hm, wm = mask.shape sh, sw = hm//h, wm//w mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss) loss = (loss * mask0).sum() / mask0.sum() else: loss = loss.mean() losses += [loss] return sum(losses) ## from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and 
self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.norm_layer = norm_layer if norm_layer is not None: self.bn1 = norm_layer(planes) self.bn2 = norm_layer(planes) if inplanes != planes: self.downsample = nn.Sequential( conv1x1(inplanes, planes, stride), norm_layer(planes), ) else: self.downsample = None self.stride = stride def forward(self, x): identity = x out = self.conv1(x) if self.norm_layer is not None: out = self.bn1(out) out = self.relu(out) out = self.conv2(out) if self.norm_layer is not None: out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResEncoder(nn.Module): def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): super().__init__() network = [ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 # nn.GroupNorm(16, nf), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 # nn.GroupNorm(16*2, nf*2), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), BasicBlock(nf*2, nf*2, norm_layer=None), BasicBlock(nf*2, nf*2, norm_layer=None), nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 # nn.GroupNorm(16*4, nf*4), # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), BasicBlock(nf*4, nf*4, norm_layer=None), BasicBlock(nf*4, nf*4, norm_layer=None), nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), BasicBlock(nf*8, nf*8, norm_layer=None), BasicBlock(nf*8, nf*8, norm_layer=None), ] add_downsample = int(np.log2(in_size//64)) if add_downsample > 0: for _ in range(add_downsample): network += [ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), BasicBlock(nf*8, nf*8, norm_layer=None), BasicBlock(nf*8, nf*8, norm_layer=None), ] if zdim is None: network += [ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 ] else: network += [ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 # nn.ReLU(inplace=True), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), ] if activation is not None: network += [get_activation(activation)] self.network = nn.Sequential(*network) def forward(self, input): return self.network(input).reshape(input.size(0), -1) class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x, attn class ViTEncoder(nn.Module): def __init__(self, cout, which_vit='dino_vits8', 
pretrained=False, frozen=False, in_size=256, final_layer_type='none', root='/root'): super().__init__() if misc.is_main_process(): force_reload = not os.path.exists(os.path.join(root, ".cache/torch/hub/checkpoints/")) else: force_reload = False if "dinov2" in which_vit: self.ViT = torch.hub.load('facebookresearch/dinov2:main', which_vit, pretrained=pretrained, force_reload=force_reload) else: self.ViT = torch.hub.load('facebookresearch/dino:main', which_vit, pretrained=pretrained, force_reload=force_reload) if frozen: for p in self.ViT.parameters(): p.requires_grad = False if which_vit == 'dino_vits8': self.vit_feat_dim = 384 self.patch_size = 8 elif which_vit == 'dinov2_vits14': self.vit_feat_dim = 384 self.patch_size = 14 elif which_vit == 'dino_vitb8': self.vit_feat_dim = 768 self.patch_size = 8 self._feats = [] self.hook_handlers = [] if final_layer_type == 'none': pass elif final_layer_type == 'conv': self.final_layer_patch_out = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) self.final_layer_patch_key = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) elif final_layer_type == 'attention': raise NotImplementedError self.final_layer = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.fc = nn.Linear(self.vit_feat_dim, cout) else: raise NotImplementedError self.final_layer_type = final_layer_type def _get_hook(self, facet: str): """ generate a hook method for a specific block and facet. """ if facet in ['attn', 'token']: def _hook(model, input, output): self._feats.append(output) return _hook if facet == 'query': facet_idx = 0 elif facet == 'key': facet_idx = 1 elif facet == 'value': facet_idx = 2 else: raise TypeError(f"{facet} is not a supported facet.") def _inner_hook(module, input, output): input = input[0] B, N, C = input.shape qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) self._feats.append(qkv[facet_idx]) #Bxhxtxd return _inner_hook def _register_hooks(self, layers: List[int], facet: str) -> None: """ register hook to extract features. :param layers: layers from which to extract features. :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] """ for block_idx, block in enumerate(self.ViT.blocks): if block_idx in layers: if facet == 'token': self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) elif facet == 'attn': self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) elif facet in ['key', 'query', 'value']: self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) else: raise TypeError(f"{facet} is not a supported facet.") def _unregister_hooks(self) -> None: """ unregisters the hooks. should be called after feature extraction. 
""" for handle in self.hook_handlers: handle.remove() self.hook_handlers = [] def forward(self, x, return_patches=False): b, c, h, w = x.shape self._feats = [] self._register_hooks([11], 'key') #self._register_hooks([11], 'token') x = self.ViT.prepare_tokens(x) #x = self.ViT.prepare_tokens_with_masks(x) for blk in self.ViT.blocks: x = blk(x) out = self.ViT.norm(x) self._unregister_hooks() ph, pw = h // self.patch_size, w // self.patch_size patch_out = out[:, 1:] # first is class token patch_out = patch_out.reshape(b, ph, pw, self.vit_feat_dim).permute(0, 3, 1, 2) patch_key = self._feats[0][:,:,1:] # B, num_heads, num_patches, dim patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.vit_feat_dim, ph, pw) if self.final_layer_type == 'none': global_feat_out = out[:, 0].reshape(b, -1) # first is class token global_feat_key = self._feats[0][:, :, 0].reshape(b, -1) # first is class token elif self.final_layer_type == 'conv': global_feat_out = self.final_layer_patch_out(patch_out).view(b, -1) global_feat_key = self.final_layer_patch_key(patch_key).view(b, -1) elif self.final_layer_type == 'attention': raise NotImplementedError else: raise NotImplementedError if not return_patches: patch_out = patch_key = None return global_feat_out, global_feat_key, patch_out, patch_key class ArticulationNetwork(nn.Module): def __init__(self, net_type, feat_dim, pos_dim, num_layers, nf, n_harmonic_functions=0, omega0=1, activation=None, enable_articulation_idadd=False): super().__init__() if n_harmonic_functions > 0: self.posenc = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, omega0=omega0) pos_dim = pos_dim * (n_harmonic_functions * 2 + 1) else: self.posenc = None pos_dim = 4 cout = 3 if net_type == 'mlp': self.network = MLP( feat_dim + pos_dim, # + bone xyz pos and index cout, # We represent the rotation of each bone by its Euler angles ψ, θ, and φ num_layers, nf=nf, dropout=0, activation=activation ) elif net_type == 'attention': self.in_layer = nn.Sequential( nn.Linear(feat_dim + pos_dim, nf), nn.GELU(), nn.LayerNorm(nf), ) self.blocks = nn.ModuleList([ Block( dim=nf, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm) for i in range(num_layers)]) out_layer = [nn.Linear(nf, cout)] if activation: out_layer += [get_activation(activation)] self.out_layer = nn.Sequential(*out_layer) else: raise NotImplementedError self.net_type = net_type self.enable_articulation_idadd = enable_articulation_idadd def forward(self, x, pos): pos_inp = pos if self.posenc is not None: pos = torch.cat([pos, self.posenc(pos)], dim=-1) x = torch.cat([x, pos], dim=-1) if self.enable_articulation_idadd: articulation_id = pos_inp[..., -1:] x = x + articulation_id if self.net_type == 'mlp': out = self.network(x) elif self.net_type == 'attention': x = self.in_layer(x) for blk in self.blocks: x = blk(x) out = self.out_layer(x) else: raise NotImplementedError return out ## Attention block from ViT (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py) class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C 
// self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic-depth helper used by DropPath (standard implementation from the DINO
    codebase; added here because it was referenced but not defined in this file)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over all but the batch dimension
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class FeatureAttention(nn.Module):
    def __init__(self, vit_type, pos_dim, embedder_freq=0, zdim=128, img_size=256, activation=None):
        super().__init__()
        self.zdim = zdim
        if embedder_freq > 0:
            self.posenc = HarmonicEmbedding(n_harmonic_functions=embedder_freq, omega0=1)
            pos_dim = pos_dim * (embedder_freq * 2 + 1)
        else:
            self.posenc = None
        self.pos_dim = pos_dim

        if vit_type == 'dino_vits8':
            self.vit_feat_dim = 384
            patch_size = 8
        elif vit_type == 'dinov2_vits14':  # fixed: was `which_vit` (undefined here) and `self.patch_size`, leaving `patch_size` unset
            self.vit_feat_dim = 384
            patch_size = 14
        elif vit_type == 'dino_vitb8':
            self.vit_feat_dim = 768
            patch_size = 8
        else:
            raise NotImplementedError
        self.num_patches_per_dim = img_size // patch_size

        self.kv = nn.Sequential(
            nn.Linear(self.vit_feat_dim, zdim),
            nn.ReLU(inplace=True),
            nn.LayerNorm(zdim),
            nn.Linear(zdim, zdim*2),
        )
        self.q = nn.Sequential(
            nn.Linear(pos_dim, zdim),
            nn.ReLU(inplace=True),
            nn.LayerNorm(zdim),
            nn.Linear(zdim, zdim),
        )
        final_mlp = [
            nn.Linear(zdim, zdim),
            nn.ReLU(inplace=True),
            nn.LayerNorm(zdim),
            nn.Linear(zdim, self.vit_feat_dim)
        ]
        if activation is not None:
            final_mlp += [get_activation(activation)]
        self.final_ln = nn.Sequential(*final_mlp)

    def forward(self, x, feat):
        _, vit_feat_dim, ph, pw = feat.shape
        assert ph == pw and ph == self.num_patches_per_dim and vit_feat_dim == self.vit_feat_dim
        if self.posenc is not None:
            x = torch.cat([x, self.posenc(x)], dim=-1)
        bxf, k, c = x.shape
        assert c == self.pos_dim
        query = self.q(x)
        feat_in = feat.view(bxf, vit_feat_dim, ph*pw).permute(0, 2, 1)  # N, K, C
        k, v = self.kv(feat_in).chunk(2, dim=-1)
        attn = torch.einsum('bnd,bpd->bnp', query, k).softmax(dim=-1)
        out = torch.einsum('bnp,bpd->bnd', attn, v)
        out = self.final_ln(out)
        return out
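

# ----------------------------------------------------------------------------
# Minimal smoke test (illustrative only): shows how the coordinate-based MLPs
# above are typically called. Shapes and hyper-parameters below are arbitrary
# example values, not the ones used in training.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    B, N = 2, 1024                              # batch of 2, 1024 query points each
    pts = torch.rand(B, N, 3) * 2 - 1           # xyz coordinates in [-1, 1]
    global_feat = torch.rand(B, N, 16)          # per-point conditioning feature

    # Positional-encoding MLP conditioned on an extra 16-dim feature.
    net = MLPWithPositionalEncoding(cin=3, cout=4, num_layers=3, nf=64, extra_dim=16, symmetrize=True)
    out = net(pts, global_feat)
    print(out.shape)  # torch.Size([2, 1024, 4])

    # FiLM-conditioned variant: the style vector is produced internally from `feat`.
    net_style = MLPWithPositionalEncoding_Style(cin=3, cout=4, num_layers=3, nf=64, extra_dim=16,
                                                style_choice='film')
    out = net_style(pts, global_feat)
    print(out.shape)  # torch.Size([2, 1024, 4])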