import torch
from einops import rearrange
from torch._C import device  # NOTE(review): unused in this module (locals shadow it); kept in case other modules import it from here.

# ``pytorch3d.transforms`` is imported lazily inside ``compute_rotation`` and
# ``arap_energy`` (the only users), so this module can be imported without
# pytorch3d being installed.


def edges_to_sparse_incidence(edges, num_vertices, dtype=torch.float32):
    """Build the signed edge-vertex incidence matrix as a sparse COO tensor.

    Row ``e`` holds ``+1`` in column ``edges[e, 0]`` and ``-1`` in column
    ``edges[e, 1]``, so ``incidence @ vertices`` yields the edge vectors
    ``v_i - v_j``.

    Args:
        edges: Integer tensor of shape ``[E, 2]`` holding vertex indices.
        num_vertices: Number of vertices ``V`` (number of columns).
        dtype: Floating dtype of the matrix values. Defaults to
            ``torch.float32``, which is what the previous implementation
            always produced.

    Returns:
        Sparse COO tensor of shape ``[E, V]`` on ``edges.device``.
    """
    num_edges = edges.shape[0]
    # Row index of every non-zero: [0, 0, 1, 1, ..., E-1, E-1].
    row_indexes = torch.arange(num_edges, dtype=torch.long, device=edges.device).repeat_interleave(2)
    # Matching columns [i0, j0, i1, j1, ...]; sparse indices must be int64.
    col_indexes = edges.reshape(-1).long()
    indexes = torch.stack([row_indexes, col_indexes])
    values = torch.tensor([1.0, -1.0], dtype=dtype, device=edges.device).repeat(num_edges)
    # torch.sparse_coo_tensor, unlike the legacy torch.sparse.FloatTensor
    # constructor, accepts non-CPU indices/values and any floating dtype.
    return torch.sparse_coo_tensor(indexes, values, (num_edges, num_vertices))


def compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat):
    """Estimate per-vertex best-fit rotations between two poses via SVD.

    Adapted from:
    https://github.com/kzhou23/shape_pose_disent/blob/a8017c405892c98f52fa9775327172633290b1d8/arap.py#L76

    For every vertex, the outer products of the rest/deformed edge vectors of
    its incident edges are summed into a ``D x D`` matrix. The SVD of that
    matrix gives the closest rotation, and reflections are corrected.

    Args:
        vertices_rest_pose: ``B x V x D`` tensor.
        vertices_deformed_pose: ``B x V x D`` tensor.
        incidence_mat: Sparse ``E x V`` signed incidence matrix, e.g. from
            ``edges_to_sparse_incidence``.

    Returns:
        ``B x V x D x D`` tensor of rotation matrices.
    """
    batch_size, num_vertices, dimensions = vertices_rest_pose.shape
    # torch.sparse.mm requires matching dtypes; this is a no-op when they match.
    incidence_mat = incidence_mat.to(dtype=vertices_rest_pose.dtype)

    vertices = torch.cat((vertices_rest_pose, vertices_deformed_pose), dim=0)
    # 2B x V x D -> V x (D x 2B)
    vertices = rearrange(vertices, 'a v d -> v (d a)')
    # E x V . V x (D x 2B) -> E x (D x 2B)
    edges = torch.sparse.mm(incidence_mat, vertices)
    edges = rearrange(edges, 'e (d a) -> a e d', d=dimensions)
    rest_edges, deformed_edges = torch.split(edges, batch_size, dim=0)

    # Per-edge outer products rest_e * deformed_e^T: B x E x D x D.
    edges_outer = torch.matmul(rest_edges[:, :, :, None], deformed_edges[:, :, None, :])
    edges_outer = rearrange(edges_outer, 'b e d1 d2 -> e (b d1 d2)')

    # |incidence| (same sparsity pattern, absolute values) scatters each
    # edge's outer product onto both of its endpoints.
    abs_incidence_mat = torch.sparse_coo_tensor(
        incidence_mat._indices(), incidence_mat._values().abs(), incidence_mat.shape)

    # Transposed S: the d1/d2 swap in the rearrange transposes each D x D block.
    S = torch.sparse.mm(abs_incidence_mat.t(), edges_outer)
    S = rearrange(S, 'v (b d1 d2) -> b v d2 d1', v=num_vertices, b=batch_size, d1=dimensions, d2=dimensions)

    # SVD on gpu is extremely slow! https://github.com/pytorch/pytorch/pull/48436
    original_device = S.device
    U, _, V = torch.svd(S.cpu())
    U = U.to(original_device)
    V = V.to(original_device)

    # If U V^T is a reflection (det = -1), flip the last column of U. That is
    # the column of the smallest singular value, because torch.svd returns
    # singular values in descending order. The result is then a proper rotation.
    det_sign = torch.sign(torch.det(torch.matmul(U, V.transpose(-2, -1))))
    U = torch.cat([U[..., :-1], U[..., -1:] * det_sign[..., None, None]], dim=-1)
    return torch.matmul(U, V.transpose(-2, -1))


def compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges):
    """Fit per-vertex rotations between two poses and return them as quaternions.

    Args:
        vertices_rest_pose: ``B x V x D`` tensor.
        vertices_deformed_pose: ``B x V x D`` tensor.
        edges: ``E x 2`` integer tensor of vertex indices.

    Returns:
        ``B x V x 4`` quaternions, as produced by
        ``pytorch3d.transforms.matrix_to_quaternion``.
    """
    import pytorch3d.transforms  # lazy: see note at top of module

    num_vertices = vertices_rest_pose.shape[1]
    # Build the incidence matrix in the vertex dtype so sparse.mm types agree.
    incidence_mat = edges_to_sparse_incidence(edges, num_vertices, dtype=vertices_rest_pose.dtype)
    rot = compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat)
    return pytorch3d.transforms.matrix_to_quaternion(rot)


def quaternion_normalize(quaternion, eps=1e-12):
    """Normalize quaternions to unit length along the last dimension.

    Adapted from tensorflow_graphics.

    Note:
        In the following, A1 to An are optional batch dimensions.

    Args:
        quaternion: A tensor of shape `[A1, ..., An, 4]`, where the last
            dimension represents a quaternion.
        eps: Lower bound applied to the squared norm before the inverse square
            root, to avoid division by zero. Defaults to 1e-12.

    Returns:
        A tensor of shape `[A1, ..., An, 4]` with the normalized quaternions.
    """
    return l2_normalize(quaternion, dim=-1, epsilon=eps)


def l2_normalize(x, dim=-1, epsilon=1e-12):
    """Scale ``x`` to unit L2 norm along ``dim``.

    Like ``tf.math.l2_normalize``, ``epsilon`` is a lower bound on the *squared*
    norm, so near-zero vectors are scaled by ``1 / sqrt(epsilon)`` instead of
    dividing by zero.

    Args:
        x: Input tensor.
        dim: Dimension to normalize over. Defaults to the last one.
        epsilon: Lower bound for the squared norm.

    Returns:
        Tensor with the same shape as ``x``.
    """
    square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
    x_inv_norm = torch.rsqrt(torch.clamp(square_sum, min=epsilon))
    return x * x_inv_norm


def arap_energy(vertices_rest_pose, vertices_deformed_pose, quaternions, edges,
                vertex_weight=None, edge_weight=None, conformal_energy=True,
                aggregate_loss=True):
    """Estimate an As Conformal As Possible (ACAP) fitting energy.

    Adapted from tensorflow_graphics.

    For a given mesh in rest pose, this function evaluates a variant of the
    ACAP [1] fitting energy for a batch of deformed meshes. The vertex weights
    and edge weights are defined on the rest pose.

    The method implemented here is similar to [2], but with an added free
    variable capturing a scale factor per vertex.

    [1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro.
      "As-Conformal-As-Possible Surface Registration." Computer Graphics Forum.
      Vol. 33. No. 5. 2014.
    [2]: Olga Sorkine, and Marc Alexa. "As-rigid-as-possible surface modeling".
      Symposium on Geometry Processing. Vol. 4. 2007.

    Note:
        In the description of the arguments, V is the number of vertices in the
        mesh and E the number of edges. A1 to An are optional batch dimensions.

    Args:
        vertices_rest_pose: A tensor of shape `[V, 3]` or `[A1, ..., An, V, 3]`
            containing the positions of all the vertices of the mesh in rest
            pose.
        vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]`
            containing the positions of all the vertices of the mesh in
            deformed pose.
        quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid
            transformation to apply to each vertex of the rest pose. See
            Section 2 of [1] for further details.
        edges: A tensor of shape `[E, 2]` defining indices of vertices that are
            connected by an edge.
        vertex_weight: An optional tensor of shape `[V]` defining the weight
            associated with each vertex. Defaults to a tensor of ones.
        edge_weight: An optional tensor of shape `[E]` defining the weight of
            edges. Common choices include uniform and cotangent weights.
            Defaults to a tensor of ones.
        conformal_energy: A `bool` indicating whether each vertex is associated
            with a scale factor. If True, scaling information must be encoded
            in the norm of `quaternions`. If False, the quaternions are
            normalized and this function implements the energy described in
            [2].
        aggregate_loss: A `bool` defining whether the returned loss should be
            an aggregate measure. When True, the mean squared error is
            returned. When False, returns two losses for every edge of the
            mesh.

    Returns:
        When aggregate_loss is `True`, a tensor of shape `[A1, ..., An]`
        containing the ACAP energies. When aggregate_loss is `False`, a tensor
        of shape `[A1, ..., An, 2*E]` containing each term of the summation
        described in equation 7 of [2].

    Raises:
        ValueError: if `edges` is not of shape `[E, 2]`, if the vertex tensors
            do not have a last dimension of 3, or if `quaternions` does not
            have a last dimension of 4.
    """
    import pytorch3d.transforms  # lazy: see note at top of module

    if edges.dim() != 2 or edges.shape[-1] != 2:
        raise ValueError(f"edges must have shape [E, 2], got {tuple(edges.shape)}")
    if vertices_rest_pose.shape[-1] != 3 or vertices_deformed_pose.shape[-1] != 3:
        raise ValueError(
            "vertices must have a last dimension of 3, got "
            f"{tuple(vertices_rest_pose.shape)} and {tuple(vertices_deformed_pose.shape)}")
    if quaternions.shape[-1] != 4:
        raise ValueError(f"quaternions must have a last dimension of 4, got {tuple(quaternions.shape)}")

    if not conformal_energy:
        quaternions = quaternion_normalize(quaternions)

    # Extracts the indices of vertices.
    indices_i, indices_j = torch.unbind(edges, dim=-1)
    num_edges = indices_i.shape[0]

    # Extracts the vertices we need per term.
    vertices_i_rest = vertices_rest_pose[..., indices_i, :]
    vertices_j_rest = vertices_rest_pose[..., indices_j, :]
    vertices_i_deformed = vertices_deformed_pose[..., indices_i, :]
    vertices_j_deformed = vertices_deformed_pose[..., indices_j, :]

    # Extracts the weights we need per term (E x 1, broadcast over xyz).
    if vertex_weight is not None:
        weight_i = vertex_weight[indices_i]
        weight_j = vertex_weight[indices_j]
    else:
        weight_i = weight_j = torch.ones(
            num_edges, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
    weight_i = weight_i[..., None]
    weight_j = weight_j[..., None]
    if edge_weight is not None:
        weight_ij = edge_weight
    else:
        weight_ij = torch.ones(
            num_edges, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
    weight_ij = weight_ij[..., None]

    # Extracts the rotation we need per term.
    quaternion_i = quaternions[..., indices_i, :]
    quaternion_j = quaternions[..., indices_j, :]

    # Computes the energy in both edge directions, each using the rotation of
    # the edge's source vertex.
    deformed_ij = vertices_i_deformed - vertices_j_deformed
    rotated_rest_ij = pytorch3d.transforms.quaternion_apply(quaternion_i, (vertices_i_rest - vertices_j_rest))
    energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij)

    deformed_ji = vertices_j_deformed - vertices_i_deformed
    rotated_rest_ji = pytorch3d.transforms.quaternion_apply(quaternion_j, (vertices_j_rest - vertices_i_rest))
    energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji)

    energy_ij_squared = torch.sum(energy_ij ** 2, dim=-1)
    energy_ji_squared = torch.sum(energy_ji ** 2, dim=-1)
    if aggregate_loss:
        average_energy_ij = torch.mean(energy_ij_squared, dim=-1)
        average_energy_ji = torch.mean(energy_ji_squared, dim=-1)
        return (average_energy_ij + average_energy_ji) / 2.0
    return torch.cat((energy_ij_squared, energy_ji_squared), dim=-1)


def arap_loss(vertices_rest_pose, vertices_deformed_pose, edges):
    """As-rigid-as-possible energy of deformed meshes relative to rest meshes.

    Per-vertex rotations are fitted with ``compute_rotation`` and then detached,
    so no gradient flows through the SVD fit. The energy is evaluated with
    ``arap_energy`` (non-conformal, mean-aggregated).

    Args:
        vertices_rest_pose: Tensor of shape ``[..., V, 3]``. Its batch
            dimensions are broadcast against those of
            ``vertices_deformed_pose``, so an unbatched rest pose may be
            shared by a batch of deformed poses.
        vertices_deformed_pose: Tensor of shape ``[..., V, 3]``.
        edges: Integer tensor of shape ``[E, 2]``.

    Returns:
        Tensor with the (broadcast) batch shape ``[...]``: one energy per mesh.
    """
    # Broadcast so both poses share one batch shape. This is a no-op when the
    # shapes already match, which is what the flattening below requires.
    vertices_rest_pose, vertices_deformed_pose = torch.broadcast_tensors(
        vertices_rest_pose, vertices_deformed_pose)
    batch_shape = vertices_rest_pose.shape[:-2]
    mesh_shape = vertices_rest_pose.shape[-2:]

    # Squash batch dimensions into a single leading one: B x V x 3.
    rest_flat = vertices_rest_pose.reshape(-1, *mesh_shape)
    deformed_flat = vertices_deformed_pose.reshape(-1, *mesh_shape)

    quaternions = compute_rotation(rest_flat, deformed_flat, edges).detach()
    energy = arap_energy(
        rest_flat, deformed_flat, quaternions, edges,
        aggregate_loss=True, conformal_energy=False)
    return energy.reshape(batch_shape)