# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch

from ..render import mesh
from ..render import render
from ..render import regularizer

###############################################################################
#  Geometry interface
###############################################################################

class DLMesh(torch.nn.Module):
    """Geometry backed by an explicit triangle mesh with trainable vertices.

    A clone of ``initial_guess`` is kept as the working mesh, and its
    ``v_pos`` tensor is wrapped in a ``torch.nn.Parameter`` registered under
    the name ``'vertex_pos'``. The untouched ``initial_guess`` is retained
    because the "relative" Laplacian regularizer measures displacement
    against it.
    """

    def __init__(self, initial_guess, FLAGS):
        """Clone ``initial_guess`` and register its vertex positions.

        Args:
            initial_guess: Mesh object providing ``clone()``, ``v_pos`` and
                ``t_pos_idx``.
            FLAGS: Configuration object. This class reads ``layers``,
                ``iter``, ``laplace`` and ``laplace_scale`` from it.
        """
        super().__init__()

        self.FLAGS = FLAGS

        self.initial_guess = initial_guess
        self.mesh = initial_guess.clone()
        print("Base mesh has %d triangles and %d vertices." % (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0]))

        # The mesh attribute and the registered parameter are the same
        # object, so optimizer updates are visible through self.mesh.v_pos.
        self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True)
        self.register_parameter('vertex_pos', self.mesh.v_pos)

    @torch.no_grad()
    def getAABB(self):
        """Return ``mesh.aabb`` of the current mesh, without tracking gradients."""
        return mesh.aabb(self.mesh)

    def getMesh(self, material):
        """Build a renderable mesh that uses ``material``.

        Note that this assigns ``material`` onto ``self.mesh`` as a side
        effect before deriving the returned mesh from it.

        Args:
            material: Material object to attach to the mesh.

        Returns:
            A new mesh based on ``self.mesh`` after ``mesh.auto_normals`` and
            ``mesh.compute_tangents`` have been applied.
        """
        self.mesh.material = material

        imesh = mesh.Mesh(base=self.mesh)
        # Compute normals and tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh

    def render(self, glctx, target, lgt, opt_material, bsdf=None):
        """Render the current mesh under the conditions described by ``target``.

        Args:
            glctx: Rasterization context forwarded to ``render.render_mesh``.
            target: Mapping read for the keys ``'mvp'``, ``'campos'``,
                ``'resolution'``, ``'spp'`` and ``'background'``.
            lgt: Light object forwarded to the renderer.
            opt_material: Material attached to the mesh before rendering.
            bsdf: Optional BSDF override forwarded to the renderer.

        Returns:
            Whatever ``render.render_mesh`` returns; ``tick`` indexes it with
            the keys ``'shaded'``, ``'kd_grad'`` and ``'occlusion'``.
        """
        opt_mesh = self.getMesh(opt_material)
        return render.render_mesh(
            glctx,
            opt_mesh,
            target['mvp'],
            target['campos'],
            lgt,
            target['resolution'],
            spp=target['spp'],
            num_layers=self.FLAGS.layers,
            msaa=True,
            background=target['background'],
            bsdf=bsdf,
        )

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
        """Render once and compute the image and regularization losses.

        Args:
            glctx: Rasterization context forwarded to ``render``.
            target: Mapping with the render inputs (see ``render``) plus
                ``'img'``, the reference image. Channels ``0:3`` are treated
                as color and channels ``3:`` as coverage.
            lgt: Light object; must provide ``regularizer()``.
            opt_material: Material being optimized.
            loss_fn: Callable ``loss_fn(prediction, reference)`` used for the
                color component of the image loss.
            iteration: Current iteration index, used to schedule the
                regularizer weights.

        Returns:
            Tuple ``(img_loss, reg_loss)``.
        """
        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        buffers = self.render(glctx, target, lgt, opt_material)

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        t_iter = iteration / self.FLAGS.iter

        # Image-space loss, split into a coverage component and a color component
        color_ref = target['img']
        ref_coverage = color_ref[..., 3:]
        img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], ref_coverage)
        img_loss += loss_fn(buffers['shaded'][..., 0:3] * ref_coverage, color_ref[..., 0:3] * ref_coverage)

        # Allocate the accumulator on the same device as the vertex
        # positions rather than a hard-coded "cuda": the in-place additions
        # below require matching devices, and "cuda" resolves to the
        # *current* CUDA device, which may differ from the mesh's device.
        reg_loss = torch.zeros(1, dtype=torch.float32, device=self.mesh.v_pos.device)

        # Compute regularizer. The Laplacian weight decays linearly to zero
        # over the FLAGS.iter iterations.
        if self.FLAGS.laplace == "absolute":
            reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
        elif self.FLAGS.laplace == "relative":
            reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)

        # Linear warm-up shared by the albedo and visibility regularizers:
        # ramps from 0 to 1 over the first 500 iterations, then stays at 1.
        warmup = min(1.0, iteration / 500)

        # Albedo (k_d) smoothness regularizer
        # NOTE(review): the last channel multiplies the others and is
        # presumably a coverage/alpha mask - confirm against render_mesh.
        kd_grad = buffers['kd_grad']
        reg_loss += torch.mean(kd_grad[..., :-1] * kd_grad[..., -1:]) * 0.03 * warmup

        # Visibility regularizer
        occlusion = buffers['occlusion']
        reg_loss += torch.mean(occlusion[..., :-1] * occlusion[..., -1:]) * 0.001 * warmup

        # Light white balance regularizer
        reg_loss = reg_loss + lgt.regularizer() * 0.005

        return img_loss, reg_loss