# Source: 3DFauna_demo/video3d/triplane_texture/lift_architecture.py
# Author: kyleleey -- commit 98a77e0 ("first commit")
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from typing import Union, List, Tuple
import os
import video3d.utils.misc as misc
import torch.nn.functional as F
class Lift_Encoder(nn.Module):
    """Lift a 2D feature map onto three axis-aligned feature planes.

    The points of three planes (x=0, y=0, z=0) are projected with an MVP
    matrix, image features are sampled at the projected locations, the
    resulting planes are refined with a few convolutions, and per-point
    features are finally read back from the planes and fused by a linear
    layer.

    Plane order everywhere in this class is ``[x=0, y=0, z=0]``.
    """

    def __init__(
        self,
        cin,
        feat_dim,
        grid_scale=7.,
        grid_size=32,
        optim_latent=False,
        img_size=256,
        with_z_feature=False,
        cam_pos_z_offset=10.
    ):
        """Build the encoder.

        Each plane spans (-1, -1) * grid_scale to (1, 1) * grid_scale.

        Args:
            cin: channel count of the incoming feature map.
            feat_dim: channel count used for the planes and the output.
            grid_scale: half extent of each plane (see above).
            grid_size: number of samples per plane side.
            optim_latent: if True, add a learnable ``[3, feat_dim, S, S]``
                tensor to the convolved planes.
            img_size: stored on the instance; not used inside this class.
            with_z_feature: if True, append the depth value returned by
                ``get_uv_z`` as an extra channel before the bottleneck conv.
            cam_pos_z_offset: offset added to the projected z value when
                computing the depth channel.
        """
        super().__init__()

        self.cin = cin
        self.nf = feat_dim
        self.grid_scale = grid_scale
        self.grid_size = grid_size
        self.img_size = img_size
        self.with_z_feature = with_z_feature
        self.cam_pos_z_offset = cam_pos_z_offset

        def conv3x3(c_in, c_out):
            # All convolutions in this module share the same geometry.
            return nn.Conv2d(
                c_in, c_out,
                kernel_size=3, stride=1, padding=1, dilation=1,
                padding_mode="replicate",
            )

        # NOTE: submodules are created in the original order so that seeded
        # initialisation and state_dict key order stay unchanged.
        self.feature_projector = nn.Linear(cin, feat_dim, bias=False)

        # Cache of the last convolved planes; filled by forward("unproject").
        self.plane_latent = None
        if optim_latent:
            self.optim_latent = nn.Parameter(torch.rand(3, feat_dim, grid_size, grid_size))
        else:
            self.optim_latent = None

        # The bottleneck takes one extra channel when depth is appended.
        bottleneck_in = feat_dim + 1 if with_z_feature else feat_dim
        self.conv_bottleneck = conv3x3(bottleneck_in, feat_dim)

        #TODO: implement an upsampler for input feature map here?
        # Two conv + PixelShuffle(2) stages: 4x spatial upsampling in total.
        self.conv_1 = conv3x3(feat_dim, 4 * feat_dim)
        self.conv_2 = conv3x3(feat_dim, 4 * feat_dim)
        self.up = nn.PixelShuffle(2)

        self.conv_enc = conv3x3(feat_dim, feat_dim // 2)
        self.conv_dec = conv3x3(feat_dim // 2, feat_dim)

        self.feature_fusion = nn.Linear(3 * feat_dim, feat_dim, bias=False)

    def get_coords(self, grid_size):
        """Return the 3D coordinates of the three plane grids.

        Args:
            grid_size: number of samples per plane side (S).

        Returns:
            Tensor of shape ``[3, S, S, 3]`` (CPU), ordered ``[x=0, y=0, z=0]``.
            In-plane coordinates run from ``-grid_scale`` up to (but not
            including) ``+grid_scale``.
        """
        with torch.no_grad():
            half = grid_size // 2
            idx = torch.arange(0, grid_size)
            grid_a, grid_b = torch.meshgrid([idx, idx], indexing="ij")
            in_plane = torch.stack([grid_a, grid_b], dim=-1)
            # Normalise with the *argument* so the method is self-consistent
            # even when called with a size different from self.grid_size.
            in_plane = (in_plane - half) / half * self.grid_scale

            zeros = torch.zeros(list(in_plane.shape[:-1]) + [1])
            plane_z0 = torch.cat([in_plane, zeros], dim=-1)  # [S, S, 3] -> (a, b, 0)
            plane_y0 = plane_z0[..., [0, 2, 1]]              # (a, 0, b)
            plane_x0 = plane_z0[..., [2, 0, 1]]              # (0, a, b)
            return torch.stack([plane_x0, plane_y0, plane_z0], dim=0)  # [3, S, S, 3]

    def get_uv_z(self, pts, mvp):
        """Project points with ``mvp`` and return image coords plus a depth value.

        Args:
            pts: ``[B, N, 3]`` points.
            mvp: ``[B, 4, 4]`` matrices applied as ``[pts, 1] @ mvp^T``.

        Returns:
            ``(cam_uv, cam_depth)`` with shapes ``[B, N, 2]`` and ``[B, N, 1]``.
            ``cam_uv`` is the perspective-divided x/y; ``cam_depth`` is the
            perspective-divided z shifted by ``cam_pos_z_offset`` and scaled
            by ``2 / grid_scale``.
        """
        homogeneous = F.pad(pts, pad=(0, 1), mode="constant", value=1.0)
        projected = torch.matmul(homogeneous, torch.transpose(mvp, 1, 2))
        divided = projected[..., :3] / projected[..., 3:4]

        cam_uv = divided[..., :2]
        # Only the z channel is needed, so offset it with a scalar instead of
        # building a 3-vector tensor on the CPU and moving it to the device.
        cam_depth = (divided[..., 2:3] + self.cam_pos_z_offset) / self.grid_scale * 2
        return cam_uv, cam_depth

    def unproject(self, feature_map, mvp):
        """Sample image features at the projected plane points.

        Args:
            feature_map: ``[B, C, h, w]``.
            mvp: ``[B, 4, 4]``.

        Returns:
            ``[B, 3, S, S, nf]`` plane features (``nf + 1`` channels when
            ``with_z_feature`` is set), with ``S = self.grid_size``.

        Side effect: clears the cached ``self.plane_latent``.
        """
        self.plane_latent = None

        bs, C, h, w = feature_map.shape
        device = feature_map.device

        # Per-pixel linear projection to nf channels, then 4x upsampling.
        flat = feature_map.permute(0, 2, 3, 1).reshape(-1, C)
        feature_map = self.feature_projector(flat).reshape(bs, h, w, self.nf).permute(0, 3, 1, 2)
        feature_map = self.up(self.conv_1(feature_map))
        feature_map = self.up(self.conv_2(feature_map))

        # Move the (batch-independent) coordinates once, then expand per batch.
        plane_coords = self.get_coords(self.grid_size).to(device)
        plane_pts = plane_coords.unsqueeze(0).expand(bs, -1, -1, -1, -1).reshape(bs, -1, 3)  # [B, N_POINTS, 3]

        plane_uv, plane_z = self.get_uv_z(plane_pts, mvp)
        # Sampling locations are treated as constants: no gradient to mvp.
        plane_uv = plane_uv.detach()
        plane_z = plane_z.detach()

        n_pts = plane_pts.shape[1]
        sampled = F.grid_sample(
            feature_map,
            plane_uv.reshape(bs, 1, n_pts, 2),
            mode="bilinear",
            padding_mode="zeros",
        )
        plane_feature = sampled.squeeze(dim=-2).permute(0, 2, 1)  # [B, N_POINTS, nf]

        if self.with_z_feature:
            plane_feature = torch.cat([plane_feature, plane_z], dim=-1)

        return plane_feature.reshape(bs, 3, self.grid_size, self.grid_size, plane_feature.shape[-1])

    def conv_plane(self, plane_feature):
        """Refine the plane features with a bottleneck and a residual conv pair.

        Args:
            plane_feature: ``[B, 3, nh, nw, nC]`` (channels last).

        Returns:
            ``[B, 3, feat_dim, nh, nw]`` (channels first); ``optim_latent`` is
            added when it exists.
        """
        bs, _, nh, nw, nC = plane_feature.shape
        # Fold the plane axis into the batch so one conv handles all planes.
        planes = plane_feature.reshape(-1, nh, nw, nC).permute(0, 3, 1, 2)  # [bs*3, nC, nh, nw]
        planes = self.conv_bottleneck(planes)

        residual = self.conv_dec(self.conv_enc(planes))
        out = residual + planes
        out = out.reshape(bs, 3, out.shape[-3], out.shape[-2], out.shape[-1])

        if self.optim_latent is not None:
            # Broadcasts over the batch dimension.
            out = out + self.optim_latent.unsqueeze(0)
        return out

    def sample_plane(self, pts, feat):
        """Read per-point features from the three planes and concatenate them.

        Args:
            pts: ``[B, K, 3]`` query points, used directly as ``grid_sample``
                coordinates (values outside [-1, 1] are clamped to the border).
            feat: ``[B, 3, C, h, w]`` planes ordered ``[x=0, y=0, z=0]``.

        Returns:
            ``[B, K, 3 * C]`` features, concatenated in plane order.
        """
        pts_x, pts_y, pts_z = pts.unbind(dim=-1)
        # Each plane is indexed by the two coordinates it does not fix.
        plane_queries = (
            torch.stack([pts_y, pts_z], dim=-1),  # x=0 plane
            torch.stack([pts_x, pts_z], dim=-1),  # y=0 plane
            torch.stack([pts_x, pts_y], dim=-1),  # z=0 plane
        )

        # Sample every plane from its OWN feature slice. The previous code read
        # feat[:, 0] for all three queries, ignoring the y=0 and z=0 planes.
        # NOTE(review): weights trained with the earlier behaviour will see
        # different features after this fix.
        per_plane = [
            F.grid_sample(
                feat[:, plane_idx],
                query.unsqueeze(1),
                mode="bilinear",
                padding_mode="border",
            ).squeeze(-2).permute(0, 2, 1)
            for plane_idx, query in enumerate(plane_queries)
        ]
        return torch.cat(per_plane, dim=-1)

    def forward(self, feature_map, mvp, pts, inference="unproject"):
        """Compute fused plane features for the query points.

        Args:
            feature_map: ``[B, C, h, w]``; only used in "unproject" mode.
            mvp: ``[B, 4, 4]``; only used in "unproject" mode.
            pts: ``[B, K, 3]`` query points.
            inference: "unproject" builds the planes from ``feature_map`` and
                caches a detached copy in ``self.plane_latent``; "sample"
                reuses the first ``pts.shape[0]`` cached planes.

        Returns:
            ``[B, K, feat_dim]`` per-point features.

        Raises:
            RuntimeError: in "sample" mode when no planes have been cached.
        """
        assert inference in ["unproject", "sample"], \
            'inference must be "unproject" or "sample"'

        if inference == "unproject":
            plane_feature = self.conv_plane(self.unproject(feature_map, mvp))
            self.plane_latent = plane_feature.detach().clone()  # this is just for test case
            feat_to_sample = plane_feature
        else:
            if self.plane_latent is None:
                raise RuntimeError(
                    'no cached planes: run forward(..., inference="unproject") '
                    'before inference="sample"'
                )
            new_bs = pts.shape[0]
            feat_to_sample = self.plane_latent[:new_bs]

        pts_feature = self.sample_plane(pts, feat_to_sample)
        return self.feature_fusion(pts_feature)  # [B, K, C]