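# NOTE: the following lines are repository web-page metadata that was captured
# together with the file contents; they are kept here as comments only.
#   kyleleey
#   first commit
#   98a77e0
#   raw
#   history blame contribute delete
#   No virus
#   4.01 kB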
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from ..render import mesh
from ..render import render
from ..render import regularizer
###############################################################################
# Geometry interface
###############################################################################
class DLMesh(torch.nn.Module):
    """Geometry whose mesh vertex positions are optimized directly.

    The module keeps a clone of ``initial_guess`` and turns its ``v_pos``
    tensor into a trainable ``torch.nn.Parameter`` registered under the name
    ``'vertex_pos'``.

    Args:
        initial_guess: Starting mesh. Must provide ``clone()``, ``v_pos`` and
            ``t_pos_idx``; it is also kept as the reference for the
            ``"relative"`` Laplacian regularizer.
        FLAGS: Configuration object. This class reads ``FLAGS.layers``,
            ``FLAGS.iter``, ``FLAGS.laplace`` and ``FLAGS.laplace_scale``.
    """

    def __init__(self, initial_guess, FLAGS):
        super().__init__()

        self.FLAGS = FLAGS
        self.initial_guess = initial_guess
        self.mesh = initial_guess.clone()
        print("Base mesh has %d triangles and %d vertices." % (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0]))

        # Make the vertex positions trainable and expose them through the
        # module's parameter list so an optimizer can pick them up.
        self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True)
        self.register_parameter('vertex_pos', self.mesh.v_pos)

    @torch.no_grad()
    def getAABB(self):
        """Return ``mesh.aabb`` of the current mesh (computed without gradients)."""
        return mesh.aabb(self.mesh)

    def getMesh(self, material):
        """Build a renderable mesh that uses ``material``.

        Note that ``material`` is also stored on ``self.mesh`` as a side
        effect. Normals and tangents are recomputed on every call through
        ``mesh.auto_normals`` and ``mesh.compute_tangents``.
        """
        self.mesh.material = material

        imesh = mesh.Mesh(base=self.mesh)
        # Compute normals and tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh

    def render(self, glctx, target, lgt, opt_material, bsdf=None):
        """Render the current mesh with ``opt_material`` under ``target``'s view.

        Args:
            glctx: Rendering context forwarded to ``render.render_mesh``.
            target: Mapping providing ``'mvp'``, ``'campos'``,
                ``'resolution'``, ``'spp'`` and ``'background'``.
            lgt: Light forwarded to ``render.render_mesh``.
            opt_material: Material attached to the mesh before rendering.
            bsdf: Optional BSDF override forwarded to ``render.render_mesh``.

        Returns:
            Whatever ``render.render_mesh`` returns (``tick`` indexes it with
            the keys ``'shaded'``, ``'kd_grad'`` and ``'occlusion'``).
        """
        opt_mesh = self.getMesh(opt_material)
        return render.render_mesh(
            glctx,
            opt_mesh,
            target['mvp'],
            target['campos'],
            lgt,
            target['resolution'],
            spp=target['spp'],
            num_layers=self.FLAGS.layers,
            msaa=True,
            background=target['background'],
            bsdf=bsdf,
        )

    def _image_loss(self, buffers, target, loss_fn):
        """Image-space loss, split into a coverage component and a color component."""
        color_ref = target['img']
        rendered = buffers['shaded']

        # Channels from index 3 onward are compared directly with MSE
        # (the coverage component).
        img_loss = torch.nn.functional.mse_loss(rendered[..., 3:], color_ref[..., 3:])
        # The first three channels are compared with ``loss_fn`` after both
        # sides are multiplied by the reference's channel(s) 3 onward.
        img_loss += loss_fn(rendered[..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
        return img_loss

    def _regularizer_loss(self, buffers, lgt, iteration):
        """Sum of the Laplacian, albedo, visibility and light regularizers."""
        t_iter = iteration / self.FLAGS.iter
        # Weight that ramps linearly from 0 to 1 over the first 500 iterations.
        warmup = min(1.0, iteration / 500)

        # Allocate the accumulator on the device that holds the vertex
        # positions instead of hard-coding "cuda", so the sum does not hit a
        # device mismatch when the mesh lives on a non-default device.
        reg_loss = torch.zeros(1, dtype=torch.float32, device=self.mesh.v_pos.device)

        # Compute regularizer. Its weight decays linearly with (1 - t_iter).
        if self.FLAGS.laplace == "absolute":
            reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
        elif self.FLAGS.laplace == "relative":
            # Regularize the displacement from the initial guess rather than
            # the absolute vertex positions.
            reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)

        # Albedo (k_d) smoothness regularizer
        # NOTE(review): the last channel presumably acts as a mask/weight for
        # the other channels -- confirm against the renderer.
        reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * warmup

        # Visibility regularizer
        reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * warmup

        # Light white balance regularizer
        reg_loss = reg_loss + lgt.regularizer() * 0.005
        return reg_loss

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
        """Render the optimizable object and compute the losses for one step.

        Args:
            glctx: Rendering context forwarded to ``render``.
            target: Mapping with the view parameters used by ``render`` plus
                the reference image under ``'img'``.
            lgt: Light; must also provide ``regularizer()``.
            opt_material: Material being optimized.
            loss_fn: Callable ``(prediction, reference) -> loss`` used for the
                color component of the image loss.
            iteration: Current iteration; drives the regularizer schedules
                together with ``FLAGS.iter``.

        Returns:
            Tuple ``(img_loss, reg_loss)``.
        """
        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        buffers = self.render(glctx, target, lgt, opt_material)

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        img_loss = self._image_loss(buffers, target, loss_fn)
        reg_loss = self._regularizer_loss(buffers, lgt, iteration)

        return img_loss, reg_loss