import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pytorch3d
import pytorch3d.loss
import pytorch3d.renderer
import pytorch3d.structures
import pytorch3d.io
import pytorch3d.transforms
from PIL import Image
from .utils import sphere
from einops import rearrange


def update_camera_pose(cameras, position, at):
    # Look-at extrinsics: R orients the camera towards `at`, T places it at `position`.
    cameras.R = pytorch3d.renderer.look_at_rotation(position, at).to(cameras.device)
    cameras.T = -torch.bmm(cameras.R.transpose(1, 2), position[:, :, None])[:, :, 0]


def get_soft_rasterizer_settings(image_size, sigma=1e-6, gamma=1e-6, faces_per_pixel=30):
    blend_params = pytorch3d.renderer.BlendParams(sigma=sigma, gamma=gamma)
    settings = pytorch3d.renderer.RasterizationSettings(
        image_size=image_size,
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=faces_per_pixel,
    )
    return settings, blend_params
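
# Note: with the default sigma=1e-6, the blur radius above evaluates to
# log(1/1e-4 - 1) * 1e-6 ~= 9.2e-6 in NDC units, i.e. a nearly hard rasterization.
# Larger sigma/gamma soften silhouette gradients at the cost of blurrier edges.
# Example call (values assumed for illustration, not taken from any config here):
#   settings, blend_params = get_soft_rasterizer_settings(image_size=128, sigma=1e-5, faces_per_pixel=50)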


class Renderer(nn.Module):
    def __init__(self, cfgs):
        super().__init__()
        self.cfgs = cfgs
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('out_image_size', 64)
        self.full_size_h = cfgs.get('full_size_h', 1080)
        self.full_size_w = cfgs.get('full_size_w', 1920)
        self.fov_w = cfgs.get('fov_w', 60)
        # self.fov_h = cfgs.get('fov_h', 30)
        # Derive the vertical FOV from the horizontal FOV and the full-frame aspect ratio.
        self.fov_h = np.arctan(np.tan(self.fov_w / 2 / 180 * np.pi) / self.full_size_w * self.full_size_h) * 2 / np.pi * 180
        self.crop_fov_approx = cfgs.get('crop_fov_approx', 25)
        self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
        self.max_range = np.tan(min(self.fov_h, self.fov_w) / 2 / 180 * np.pi) * self.cam_pos_z_offset
        cam_pos = torch.FloatTensor([[0, 0, self.cam_pos_z_offset]]).to(self.device)
        cam_at = torch.FloatTensor([[0, 0, 0]]).to(self.device)
        self.rot_rep = cfgs.get('rot_rep', 'euler_angle')
        self.cameras = pytorch3d.renderer.FoVPerspectiveCameras(fov=self.crop_fov_approx).to(self.device)
        update_camera_pose(self.cameras, position=cam_pos, at=cam_at)
        self.full_cameras = pytorch3d.renderer.FoVPerspectiveCameras(fov=self.fov_w).to(self.device)
        update_camera_pose(self.full_cameras, position=cam_pos, at=cam_at)
        self.image_renderer = self._create_image_renderer()

        self.ico_sphere_subdiv = cfgs.get('ico_sphere_subdiv', 2)
        self.init_shape_scale_xy = float(cfgs.get('init_shape_scale_xy', 1.))
        self.init_shape_scale_z = float(cfgs.get('init_shape_scale_z', 1.))
        # init_verts, init_faces, init_aux = pytorch3d.io.load_obj(cfgs['init_shape_obj'], create_texture_atlas=True, device=self.device)
        # self.init_verts = init_verts * self.init_shape_scale
        # self.meshes = pytorch3d.structures.Meshes(verts=[self.init_verts], faces=[init_faces.verts_idx]).to(self.device)
        # self.tex_faces_uv = init_faces.textures_idx.unsqueeze(0)
        # self.tex_verts_uv = init_aux.verts_uvs.unsqueeze(0)
        # self.texture_atlas = init_aux.texture_atlas.unsqueeze(0)
        # self.num_verts_total = init_verts.size(0)
        # cmap = plt.cm.get_cmap('hsv', self.num_verts_total)
        # verts_texture = cmap(np.random.permutation(self.num_verts_total))[:, :3]
        # self.verts_texture = torch.FloatTensor(verts_texture)
        # debug_uvtex = cfgs.get('debug_uvtex', None)
        # if debug_uvtex is not None:
        #     face_tex_map = Image.open(debug_uvtex).convert('RGB').resize((512, 512))
        #     self.face_tex_map = torch.FloatTensor(np.array(face_tex_map)).permute(2, 0, 1) / 255.
        # else:
        #     self.face_tex_map = None
        meshes, aux = sphere.get_symmetric_ico_sphere(subdiv=self.ico_sphere_subdiv, return_tex_uv=True, return_face_tex_map=True, device=self.device)
        init_verts = meshes.verts_padded()
        self.init_verts = init_verts * torch.FloatTensor([self.init_shape_scale_xy, self.init_shape_scale_xy, self.init_shape_scale_z]).view(1, 1, 3).to(init_verts.device)
        self.meshes = meshes.update_padded(init_verts * 0)  # TODO: is this needed?
        self.tex_faces_uv = aux['face_tex_ids'].unsqueeze(0)
        self.tex_verts_uv = aux['verts_tex_uv'].unsqueeze(0)
        self.face_tex_map = aux['face_tex_map'].permute(2, 0, 1)
        self.tex_map_seam_mask = aux['seam_mask'].permute(2, 0, 1)
        self.num_verts_total = init_verts.size(1)
        self.num_verts_seam = aux['num_verts_seam']
        self.num_verts_one_side = aux['num_verts_one_side']

        # hack to turn off texture symmetry: give each half of the mesh its own half
        # of a double-width texture map
        if cfgs.get('disable_sym_tex', False):
            tex_uv_seam1 = self.tex_verts_uv[:, :aux['num_verts_seam']].clone()
            tex_uv_seam1[:, :, 0] = tex_uv_seam1[:, :, 0] / 2 + 0.5
            tex_uv_side1 = self.tex_verts_uv[:, aux['num_verts_seam']:aux['num_verts_seam'] + aux['num_verts_one_side']].clone()
            tex_uv_side1[:, :, 0] = tex_uv_side1[:, :, 0] / 2 + 0.5
            tex_uv_seam2 = self.tex_verts_uv[:, :aux['num_verts_seam']].clone()
            tex_uv_seam2[:, :, 0] = tex_uv_seam2[:, :, 0] / 2
            tex_uv_side2 = self.tex_verts_uv[:, aux['num_verts_seam'] + aux['num_verts_one_side']:].clone()
            tex_uv_side2[:, :, 0] = tex_uv_side2[:, :, 0] / 2
            self.tex_verts_uv = torch.cat([tex_uv_seam1, tex_uv_side1, tex_uv_side2, tex_uv_seam2], 1)
            num_faces = self.tex_faces_uv.shape[1]
            face_tex_ids1 = self.tex_faces_uv[:, :num_faces // 2].clone()
            face_tex_ids2 = self.tex_faces_uv[:, num_faces // 2:].clone()
            face_tex_ids2[face_tex_ids2 < aux['num_verts_seam']] += aux['num_verts_seam'] + 2 * aux['num_verts_one_side']
            self.tex_faces_uv = torch.cat([face_tex_ids1, face_tex_ids2], 1)
            self.face_tex_map = torch.cat([self.face_tex_map, self.face_tex_map.flip(2)], 2)
            self.tex_map_seam_mask = torch.cat([self.tex_map_seam_mask, self.tex_map_seam_mask.flip(2)], 2)
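    # Vertex ordering assumed throughout (as produced by sphere.get_symmetric_ico_sphere):
    # indices [0, num_verts_seam) lie on the symmetry seam (the x = 0 plane), followed by
    # num_verts_one_side vertices for each half; symmetrize_shape() below relies on this
    # layout. With disable_sym_tex, the UV map is doubled in width so each half of the
    # mesh samples its own texture region instead of sharing a mirrored one.
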
    def _create_silhouette_renderer(self):
        settings, blend_params = get_soft_rasterizer_settings(self.image_size)
        return pytorch3d.renderer.MeshRenderer(
            rasterizer=pytorch3d.renderer.MeshRasterizer(cameras=self.cameras, raster_settings=settings),
            shader=pytorch3d.renderer.SoftSilhouetteShader(cameras=self.cameras, blend_params=blend_params)
        )

    def _create_image_renderer(self):
        settings, blend_params = get_soft_rasterizer_settings(self.image_size)
        # Ambient-only lighting: the texture map is rendered without shading.
        lights = pytorch3d.renderer.DirectionalLights(
            device=self.device,
            ambient_color=((1., 1., 1.),),
            diffuse_color=((0., 0., 0.),),
            specular_color=((0., 0., 0.),),
            direction=((0, 1, 0),),
        )
        return pytorch3d.renderer.MeshRenderer(
            rasterizer=pytorch3d.renderer.MeshRasterizer(cameras=self.cameras, raster_settings=settings),
            shader=pytorch3d.renderer.SoftPhongShader(device=self.device, lights=lights, cameras=self.cameras, blend_params=blend_params)
        )

    def transform_verts(self, verts, pose):
        b, f, _ = pose.shape
        if self.rot_rep in ('euler_angle', 'soft_class'):
            rot_mat = pytorch3d.transforms.euler_angles_to_matrix(pose[..., :3].view(-1, 3), convention='XYZ')
            tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device)
        elif self.rot_rep == 'quaternion':
            rot_mat = pytorch3d.transforms.quaternion_to_matrix(pose[..., :4].view(-1, 4))
            tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device)
        elif self.rot_rep == 'lookat':
            rot_mat = pose[..., :9].view(-1, 3, 3)
            tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device)
        else:
            raise NotImplementedError
        tsf = tsf.compose(pytorch3d.transforms.Translate(pose[..., -3:].view(-1, 3), device=pose.device))
        new_verts = tsf.transform_points(verts.view(b * f, *verts.shape[2:]))
        return new_verts.view(b, f, *new_verts.shape[1:])

    # def transform_mesh(self, mesh, pose):
    #     mesh_verts = mesh.verts_padded()
    #     new_mesh_verts = self.transform_verts(mesh_verts, pose)
    #     new_mesh = mesh.update_padded(new_mesh_verts)
    #     return new_mesh

    def symmetrize_shape(self, shape):
        # Project seam vertices onto the x=0 plane and mirror one side onto the other.
        verts_seam = shape[:, :, :self.num_verts_seam] * torch.FloatTensor([0, 1, 1]).to(shape.device)
        verts_one_side = shape[:, :, self.num_verts_seam:self.num_verts_seam + self.num_verts_one_side] * torch.FloatTensor([1, 1, 1]).to(shape.device)
        verts_other_side = verts_one_side * torch.FloatTensor([-1, 1, 1]).to(shape.device)
        shape = torch.cat([verts_seam, verts_one_side, verts_other_side], 2)
        return shape

    def get_deformed_mesh(self, shape, pose=None, return_shape=False):
        b, f, _, _ = shape.shape
        if pose is not None:
            shape = self.transform_verts(shape, pose)
        mesh = self.meshes.extend(b * f)
        mesh = mesh.update_padded(rearrange(shape, 'b f ... -> (b f) ...'))
        if return_shape:
            return shape, mesh
        else:
            return mesh

    def get_textures(self, tex_im):
        b, f, c, h, w = tex_im.shape
        ## top half texture map in ico_sphere.obj is unused, pad with zeros
        # if 'sym' not in self.cfgs.get('init_shape_obj', ''):
        #     tex_im = torch.cat([torch.zeros_like(tex_im), tex_im], 3)
        #     tex_im = nn.functional.interpolate(tex_im, (h, w), mode='bilinear', align_corners=False)
        textures = pytorch3d.renderer.TexturesUV(
            maps=tex_im.view(b * f, *tex_im.shape[2:]).permute(0, 2, 3, 1),  # texture maps are BxHxWx3
            faces_uvs=self.tex_faces_uv.repeat(b * f, 1, 1),
            verts_uvs=self.tex_verts_uv.repeat(b * f, 1, 1),
        )
        return textures

    def render_flow(self, meshes, shape, pose, deformed_shape=None):
        # verts = meshes.verts_padded()  # (B*F)xVx3
        b, f, _, _ = shape.shape
        if f < 2:
            return None
        if deformed_shape is None:
            deformed_shape, meshes = self.get_deformed_mesh(shape.detach(), pose=pose, return_shape=True)
        im_size = torch.FloatTensor([self.image_size, self.image_size]).to(shape.device)  # (w,h)
        verts_2d = self.cameras.transform_points_screen(deformed_shape.view(b * f, *deformed_shape.shape[2:]), im_size.view(1, 2).repeat(b * f, 1), eps=1e-7)
        verts_2d = verts_2d.view(b, f, *verts_2d.shape[1:])
        verts_flow = verts_2d[:, 1:, :, :2] - verts_2d[:, :-1, :, :2]  # Bx(F-1)xVx(x,y)
        verts_flow = verts_flow / im_size.view(1, 1, 1, 2) * 0.5 + 0.5  # 0~1
        flow_tex = torch.nn.functional.pad(verts_flow, pad=[0, 1, 0, 0, 0, 1])  # BxFxVx3
        # meshes = meshes.detach()  # detach mesh when rendering flow (only texture has gradients)
        # meshes = self.get_deformed_mesh(shape.detach())
        meshes.textures = pytorch3d.renderer.TexturesVertex(verts_features=flow_tex.view(b * f, -1, 3))
        flow = self.image_renderer(meshes_world=meshes, cameras=self.cameras)
        # settings, blend_params = get_soft_rasterizer_settings(image_size=self.image_size, sigma=1e-6, gamma=1e-6, faces_per_pixel=5)
        # flow = self.image_renderer(meshes_world=meshes, cameras=self.cameras, raster_settings=settings, blend_params=blend_params)
        flow = flow.view(b, f, *flow.shape[1:])[:, :-1]  # Bx(F-1)xHxWx4 (RGBA)
        flow_mask = (flow[:, :, :, :, 3:] > 0.01).float()
        return (flow[:, :, :, :, :2] - 0.5) * 2 * flow_mask  # Bx(F-1)xHxWx2
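    # forward() expects pose as BxFxP (P=6 for euler_angle: 3 rotation + 3 translation),
    # texture as BxFx3xHxW UV maps, and shape as BxFxVx3 vertex positions; it returns the
    # rendered RGBA frames BxFxHxWx4, the optical flow Bx(F-1)xHxWx2 (or None), and the
    # posed mesh. crop_bbox is currently unused: the commented block below sketches how
    # it was meant to compensate square crops via the camera intrinsics.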
    def forward(self, pose, texture, shape, crop_bbox=None, render_flow=True):
        b, f, _ = pose.shape
        ## compensate crop with intrinsics, assuming square crops
        # x0, y0, w, h = crop_bbox.unbind(2)
        # fx = 1 / np.tan(self.fov_w / 2 / 180 * np.pi)
        # fy = fx
        # sx = w / self.full_size_w
        # sy = sx
        # cx = ((x0 + w/2) - (self.full_size_w/2)) / (self.full_size_w/2)  # [0-w] -> [-1,1]
        # cy = ((y0 + h/2) - (self.full_size_h/2)) / (self.full_size_w/2)
        # znear = 1
        # zfar = 100
        # v1 = zfar / (zfar - znear)
        # v2 = -(zfar * znear) / (zfar - znear)
        # # K = [[[ fx/sx, 0.0000,  cx/sx, 0.0000],
        # #       [0.0000,  fy/sy,  cy/sy, 0.0000],
        # #       [0.0000, 0.0000,     v1,     v2],
        # #       [0.0000, 0.0000, 1.0000, 0.0000]]]
        # zeros = torch.zeros_like(sx)
        # K_row1 = torch.stack([fx/sx, zeros, cx/sx, zeros], 2)
        # K_row2 = torch.stack([zeros, fy/sy, cy/sy, zeros], 2)
        # K_row3 = torch.stack([zeros, zeros, zeros + v1, zeros + v2], 2)
        # K_row4 = torch.stack([zeros, zeros, zeros + 1, zeros], 2)
        # K = torch.stack([K_row1, K_row2, K_row3, K_row4], 2)  # BxFx4x4
        # self.crop_cameras = pytorch3d.renderer.FoVPerspectiveCameras(K=K.view(-1, 4, 4), R=self.cameras.R, T=self.cameras.T, device=self.device)
        # # reset znear, zfar to scalar to bypass broadcast bug in pytorch3d blending
        # self.crop_cameras.znear = znear
        # self.crop_cameras.zfar = zfar

        deformed_shape, mesh = self.get_deformed_mesh(shape, pose=pose, return_shape=True)
        if render_flow:
            flow = self.render_flow(mesh, shape, pose, deformed_shape=deformed_shape)  # Bx(F-1)xHxWx2
            # flow = self.render_flow(mesh, shape, pose, deformed_shape=None)  # Bx(F-1)xHxWx2
        else:
            flow = None
        mesh.textures = self.get_textures(texture)
        image = self.image_renderer(meshes_world=mesh, cameras=self.cameras)
        image = image.view(b, f, *image.shape[1:])
        return image, flow, mesh
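

if __name__ == '__main__':
    # Minimal smoke-test sketch; the config values and tensor shapes below are
    # assumptions for illustration, not taken from this codebase. Run as a module
    # (e.g. `python -m <package>.renderer`) so the relative `sphere` import resolves.
    cfgs = {'device': 'cpu', 'out_image_size': 64, 'rot_rep': 'euler_angle'}
    renderer = Renderer(cfgs)
    b, f = 1, 2  # one clip of two frames, so optical flow is defined
    pose = torch.zeros(b, f, 6)  # 3 Euler angles + 3 translations; all zeros = identity pose
    shape = renderer.init_verts.unsqueeze(0).repeat(b, f, 1, 1)  # BxFxVx3 rest shape
    texture = torch.rand(b, f, 3, 256, 256)  # random BxFx3xHxW UV texture map
    image, flow, mesh = renderer(pose, texture, shape)
    print(image.shape)  # expected: (1, 2, 64, 64, 4) RGBA
    print(flow.shape)   # expected: (1, 1, 64, 64, 2)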