3DFauna_demo / video3d /triplane_texture /triplane_transformer.py
kyleleey
first commit
98a77e0
raw
history blame
No virus
5.74 kB
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from typing import Union, List, Tuple
import os
import video3d.utils.misc as misc
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Transformer_layer(nn.Module):
def __init__(self, dim_feat=384, dim=1024, hidden_dim=1024, heads=16):
super().__init__()
'''
dim: the dim between each attention, mlp, also the input and output dim for the layer
hidden_dim: the dim inside qkv
dim_feat: condition feature dim
'''
dim_head = hidden_dim // heads
self.heads = heads
self.scale = dim_head ** -0.5 # 8
self.norm = nn.LayerNorm(dim)
self.ffn = FeedForward(
dim=dim,
hidden_dim=(4 * dim),
dropout=0.
)
# cross attention part
self.to_cross_q = nn.Linear(dim, hidden_dim, bias=False)
self.to_cross_kv = nn.Linear(dim_feat, hidden_dim*2, bias=False)
self.cross_attend = nn.Softmax(dim=-1)
# self attention part
self.to_self_qkv = nn.Linear(dim, hidden_dim*3, bias=False)
self.self_attend = nn.Softmax(dim=-1)
def forward_cross_attn(self, x, feature):
x = self.norm(x)
q = self.to_cross_q(x)
k, v = self.to_cross_kv(feature).chunk(2, dim=-1)
qkv = (q, k, v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.cross_attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return out
def forward_self_attn(self, x):
x = self.norm(x)
qkv = self.to_self_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.self_attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return out
def forward(self, x, feature):
'''
x: [B, N, dim]
feature: [B, N, dim_feat]
'''
cross_token = self.forward_cross_attn(x, feature)
cross_token = cross_token + x
self_token = self.forward_self_attn(cross_token)
self_token = self_token + cross_token
out = self.ffn(self_token)
out = out + self_token
return out
class Triplane_Transformer(nn.Module):
def __init__(self, emb_dim=1024, emb_num=1024, num_layers=16,
triplane_dim=80, triplane_scale=7.):
super().__init__()
self.learnable_embedding = nn.Parameter(torch.randn(1, emb_num, emb_dim))
self.layers = nn.ModuleList([])
for _ in range(num_layers):
self.layers.append(
Transformer_layer(
dim_feat=384,
dim=emb_dim,
hidden_dim=emb_dim
)
)
self.triplane_dim = triplane_dim
self.triplane_scale = triplane_scale
self.to_triplane = nn.ConvTranspose2d(
in_channels=emb_dim,
out_channels=3 * triplane_dim,
kernel_size=4,
padding=1,
stride=2
)
self.norm = nn.LayerNorm(emb_dim)
def sample_feat(self, feat_maps, pts):
'''
feat_maps: [B, 3, C, H, W]
pts: [B, K, 3]
'''
pts = pts / (self.triplane_scale / 2)
pts_xy = pts[..., [0,1]]
pts_yz = pts[..., [1,2]]
pts_xz = pts[..., [0,2]]
feat_xy = feat_maps[:, 0, :, :, :]
feat_yz = feat_maps[:, 1, :, :, :]
feat_xz = feat_maps[:, 2, :, :, :]
sampled_feat_xy = F.grid_sample(
feat_xy, pts_xy.unsqueeze(1), mode='bilinear', align_corners=True
)
sampled_feat_yz = F.grid_sample(
feat_yz, pts_yz.unsqueeze(1), mode='bilinear', align_corners=True
)
sampled_feat_xz = F.grid_sample(
feat_xz, pts_xz.unsqueeze(1), mode='bilinear', align_corners=True
)
sampled_feat = torch.cat([sampled_feat_xy, sampled_feat_yz, sampled_feat_xz], dim=1).squeeze(-2) # [B, F, K]
sampled_feat = sampled_feat.permute(0, 2, 1)
return sampled_feat
def forward(self, feature, pts):
'''
feature: [B, N, dim_feat]
'''
batch_size = feature.shape[0]
embedding = self.learnable_embedding.repeat(batch_size, 1, 1)
x = embedding
for layer in self.layers:
x = layer(x, feature)
x = self.norm(x)
# x: [B, 32x32, 1024]
batch_size, pwph, feat_dim = x.shape
ph = int(pwph ** 0.5)
pw = int(pwph ** 0.5)
triplane_feat = x.reshape(batch_size, ph, pw, feat_dim).permute(0, 3, 1, 2)
triplane_feat = self.to_triplane(triplane_feat) # [B, C, 64, 64]
triplane_feat = triplane_feat.reshape(triplane_feat.shape[0], 3, self.triplane_dim, triplane_feat.shape[-2], triplane_feat.shape[-1])
pts_feat = self.sample_feat(triplane_feat, pts)
return pts_feat