kyleleey
first commit
98a77e0
raw
history blame
No virus
1.71 kB
import torch
from einops import repeat
def sample_farthest_points(pts, k, return_index=False):
    """Select ``k`` points per batch element by iterative farthest point sampling.

    The first point of each batch element is drawn uniformly at random with
    ``torch.randint``. Every later point is the one with the largest Euclidean
    distance to the set already selected.

    Args:
        pts: tensor of shape (B, C, N). B batch elements each hold N points,
            with the C coordinates of each point on dim 1.
        k: number of points to select per batch element. If ``k > N``, once
            every point has been taken all remaining distances are zero and
            the selection repeats points.
        return_index: when True, also return the selected indices.

    Returns:
        ``farthest_pts`` of shape (B, C, k), with the same dtype and device as
        ``pts``. When ``return_index`` is True, returns the tuple
        ``(farthest_pts, indexes)``, where ``indexes`` is an int64 tensor of
        shape (B, k) indexing dim 2 of ``pts``.
    """
    b, c, n = pts.shape
    # Size the output by the input's channel count C. A hard-coded 3 here
    # breaks every input with C != 3.
    farthest_pts = torch.zeros((b, c, k), device=pts.device, dtype=pts.dtype)
    indexes = torch.zeros((b, k), device=pts.device, dtype=torch.int64)

    if k == 0:
        # Nothing to select. Return the empty buffers instead of indexing slot 0.
        return (farthest_pts, indexes) if return_index else farthest_pts

    def take(index):
        # Map (B,) point indices to the (B, C) coordinates of those points.
        gather_index = index[:, None, None].expand(-1, c, 1)
        return torch.gather(pts, 2, gather_index)[:, :, 0]

    index = torch.randint(n, [b], device=pts.device)
    picked = take(index)
    farthest_pts[:, :, 0] = picked
    indexes[:, 0] = index
    # distances[j, p] is the distance from point p to its nearest selected point.
    distances = torch.norm(picked[:, :, None] - pts, dim=1)
    for i in range(1, k):
        # The next pick is the point farthest from everything selected so far.
        _, index = torch.max(distances, dim=1)
        picked = take(index)
        farthest_pts[:, :, i] = picked
        indexes[:, i] = index
        distances = torch.min(distances, torch.norm(picked[:, :, None] - pts, dim=1))

    return (farthest_pts, indexes) if return_index else farthest_pts
def line_segment_distance(a, b, points, sqrt=True):
    """Distance from query points to the line segment with endpoints a and b.

    ``a`` and ``b`` (shape (..., D)) get a singleton axis inserted before their
    last dim. ``points`` is therefore expected to carry an extra point axis,
    i.e. shape (..., M, D); presumably that is M query points per segment
    (confirm against callers). The last (coordinate) dim is reduced, so the
    result has shape (..., M) after broadcasting.

    Each point is projected onto the line through a and b. The projection
    parameter is then clamped to [0, 1], which keeps the foot point on the
    segment itself. Finally the distance to that foot point is measured. A
    degenerate segment (a == b) is guarded by flooring the squared length at
    1e-6, which makes the parameter 0 and measures the distance to a.

    Args:
        a, b: segment endpoints, shape (..., D).
        points: query points; see the shape note above.
        sqrt: if True, return ``sqrt(d^2 + 1e-6)``. If False, return the plain
            squared distance ``d^2``.

    Returns:
        Tensor of (squared) distances with the coordinate dim reduced away.
    """
    start = a.unsqueeze(-2)
    seg = b.unsqueeze(-2) - start
    rel = points - start

    # The projection parameter t must be clamped to [0, 1] so the foot point
    # lies on the segment rather than on the infinite line.
    seg_len_sq = (seg * seg).sum(dim=-1, keepdim=True)
    denom = torch.max(seg_len_sq, torch.tensor(1e-6, device=a.device))
    t = torch.clamp((rel * seg).sum(dim=-1, keepdim=True) / denom, 0.0, 1.0)

    foot = start + t * seg
    offset = foot - points
    dist_sq = (offset * offset).sum(dim=-1)

    # The epsilon presumably keeps the sqrt gradient finite at zero distance.
    return torch.sqrt(dist_sq + 1e-6) if sqrt else dist_sq