kyleleey
first commit
98a77e0
raw
history blame contribute delete
No virus
6.35 kB
import io
import numpy as np
import cv2
from PIL import Image
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import torch
# import pytorch3d
# import pytorch3d.renderer
# import pytorch3d.structures
# import pytorch3d.io
# import pytorch3d.transforms
# import pytorch3d.utils
## https://stackoverflow.com/a/58641662/11471407
def fig_to_img(fig, dpi=200, im_size=(512,512)):
    """Rasterize a figure into an RGB array with values scaled by 1/255.

    The figure is written as a PNG into an in-memory buffer (nothing touches
    disk), decoded with PIL, converted to RGB, resized to `im_size`, and
    finally divided by 255 so that 8-bit channel values map onto [0, 1].

    Args:
        fig: object exposing `savefig(buffer, format=..., dpi=...)`
            (a matplotlib figure in this file).
        dpi: resolution forwarded to `savefig`.
        im_size: target size forwarded to PIL's `Image.resize`.

    Returns:
        A numpy float array holding the resized RGB image.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format="png", dpi=dpi)
    # Rewind so the decoder reads from the start of the PNG stream.
    png_buffer.seek(0)
    rgb_image = Image.open(png_buffer).convert('RGB')
    resized = rgb_image.resize(im_size)
    return np.array(resized) / 255.
def get_ico_sphere(subdiv=1):
    """Return PyTorch3D's ico-sphere mesh at the given subdivision level.

    Args:
        subdiv: subdivision level, forwarded as `level` to
            `pytorch3d.utils.ico_sphere`.

    Returns:
        The mesh object produced by `pytorch3d.utils.ico_sphere`.

    Raises:
        ImportError: if PyTorch3D is not installed.
    """
    # The module-level pytorch3d imports are commented out, so the bare name
    # `pytorch3d` is undefined at call time. Import lazily here so the module
    # still loads without PyTorch3D while this function works when it is
    # available.
    from pytorch3d.utils import ico_sphere

    return ico_sphere(level=subdiv)
def get_symmetric_ico_sphere(subdiv=1, return_tex_uv=True, return_face_tex_map=True, device='cpu'):
    """Build an ico-sphere that is exactly mirror-symmetric about the yz-plane.

    The default ico-sphere is rotated about Z so that its seam lies on the
    plane x=0. Vertices with |x| < 0.001 are snapped to x=0 ("seam"), the
    vertices with x > 0 are kept ("one side"), and the other half is rebuilt
    by flipping the sign of x of the kept half. The new vertex order is
    [seam, one side, mirrored side]; the new face order is
    [faces of one side, mirrored faces with reversed winding].

    Args:
        subdiv: subdivision level passed to `get_ico_sphere`.
        return_tex_uv: if True, store per-vertex UVs (`'verts_tex_uv'`) and
            the per-face UV indices (`'face_tex_ids'`) in the aux dict.
        return_face_tex_map: if True, render a debug texture image
            (`'face_tex_map'`) and the seam mask (`'seam_mask'`) into the
            aux dict. Face colors are drawn with `np.random.randint`, so the
            image differs between calls unless numpy's RNG is seeded.
        device: device the returned mesh and aux tensors are moved to.

    Returns:
        `(mesh, aux)` where `mesh` is a `pytorch3d.structures.Meshes` on
        `device` and `aux` is a dict that always contains
        `'num_verts_seam'` and `'num_verts_one_side'`, plus the optional
        entries described above.

    Raises:
        ImportError: if PyTorch3D is not installed.
    """
    # Imported lazily because the module-level pytorch3d imports are
    # commented out; referencing the bare name `pytorch3d` would fail.
    from pytorch3d.structures import Meshes
    from pytorch3d.transforms import RotateAxisAngle

    sph_mesh = get_ico_sphere(subdiv=subdiv)
    sph_verts = sph_mesh.verts_padded()[0]
    sph_faces = sph_mesh.faces_padded()[0]

    ## rotate the default mesh s.t. the seam is exactly on yz-plane
    rot_z = np.arctan(0.5000/0.3090)  # computed from vertices in ico_sphere
    tfs = RotateAxisAngle(rot_z, 'Z', degrees=False)
    rotated_verts = tfs.transform_points(sph_verts)

    ## identify vertices on the seam (x=0) and on the kept side (x>0)
    on_seam = rotated_verts[:, 0].abs() < 0.001  # threshold 0.001
    rotated_verts[on_seam, 0] = 0.  # force seam vertices onto the plane
    on_one_side = (~on_seam) & (rotated_verts[:, 0] > 0)
    verts_id_seam = torch.nonzero(on_seam).flatten().tolist()
    verts_id_one_side = torch.nonzero(on_one_side).flatten().tolist()
    num_seam = len(verts_id_seam)
    num_one_side = len(verts_id_one_side)

    ## create a new set of symmetric vertices: [seam, one side, mirrored side]
    vid_old_to_new = {
        old_vid: new_vid
        for new_vid, old_vid in enumerate(verts_id_seam + verts_id_one_side)
    }
    verts_seam = rotated_verts[verts_id_seam]
    verts_one_side = rotated_verts[verts_id_one_side]
    verts_other_side = verts_one_side * rotated_verts.new_tensor([-1., 1., 1.])  # flip x
    new_verts = torch.cat([verts_seam, verts_one_side, verts_other_side], 0)

    ## create a new set of symmetric faces
    faces_one_side = []
    faces_other_side = []
    for old_face in sph_faces.tolist():
        # Keep only faces whose vertices are all on the seam or the kept side.
        if not all(vi in vid_old_to_new for vi in old_face):
            continue
        new_face = [vid_old_to_new[vi] for vi in old_face]
        # Seam vertices (ids < num_seam) are shared; mirrored copies of the
        # kept-side vertices are appended right after the original ones.
        mirrored_face = [nv if nv < num_seam else nv + num_one_side for nv in new_face]
        faces_one_side.append(new_face)
        faces_other_side.append(mirrored_face[::-1])  # reverse face orientation
    new_faces = torch.LongTensor(faces_one_side + faces_other_side)
    sym_sph_mesh = Meshes(verts=[new_verts], faces=[new_faces])

    aux = {}
    aux['num_verts_seam'] = num_seam
    aux['num_verts_one_side'] = num_one_side

    ## create texture map uv
    # The face texture map is drawn in UV space, so the UVs are needed
    # whenever either output is requested (previously requesting only the
    # face map crashed on an undefined variable).
    verts_tex_uv = None
    if return_tex_uv or return_face_tex_map:
        verts_tex_uv = torch.stack([-new_verts[:,2], new_verts[:,1]], 1)  # -z,y
        verts_tex_uv = verts_tex_uv / ((verts_tex_uv**2).sum(1,keepdim=True)**0.5).clamp(min=1e-8)
        # set magnitude to angle deviation from vertical axis, for more even
        # texture mapping; clamp guards acos against |x| marginally above 1
        # from float error, which would otherwise turn every UV into NaN.
        magnitude = new_verts[:,:1].abs().clamp(max=1.0).acos()
        magnitude = magnitude / magnitude.max() * 0.95  # max 0.95
        verts_tex_uv = verts_tex_uv * magnitude
        verts_tex_uv = verts_tex_uv / 2 + 0.5  # rescale to 0~1
    face_tex_ids = new_faces

    if return_tex_uv:
        aux['verts_tex_uv'] = verts_tex_uv.to(device)
        aux['face_tex_ids'] = face_tex_ids.to(device)

    ## create face color map
    if return_face_tex_map:
        dpi = 200
        im_size = (512, 512)

        def _blank_canvas():
            # Borderless figure whose single axes fills the whole canvas.
            fig = plt.figure(figsize=(8,8), dpi=dpi, frameon=False)
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            fig.add_axes(ax)
            return fig, ax

        num_colors = 10
        cmap = plt.get_cmap('tab10', num_colors)
        num_faces = len(face_tex_ids)
        # assuming symmetric faces are appended right after the original ones
        face_tex_ids_one_side = face_tex_ids[:num_faces//2]

        fig, ax = _blank_canvas()
        for face in face_tex_ids_one_side:
            vert_uv = verts_tex_uv[face]  # 3x2
            color = cmap(np.random.randint(num_colors))
            ax.add_patch(plt.Polygon(vert_uv, facecolor=color, edgecolor='black', linewidth=2))
        ## draw arrow
        ax.arrow(0.85, 0.5, -0.7, 0., length_includes_head=True, width=0.03, head_width=0.15, overhang=0.2, color='white')
        ax.set_xlim(0,1)
        ax.set_ylim(0,1)
        face_tex_map = torch.FloatTensor(fig_to_img(fig, dpi, im_size))
        plt.close(fig)  # close this figure explicitly, not just the current one

        ## draw seam
        fig, ax = _blank_canvas()
        for face in face_tex_ids_one_side:
            vert_uv = verts_tex_uv[face]  # 3x2
            # Seam vertices sit near the rim of the UV disc (radius > 0.47
            # around the centre); an edge with both ends there is a seam edge.
            vert_on_seam = ((vert_uv-0.5)**2).sum(1)**0.5 > 0.47
            if vert_on_seam.sum() == 2:
                ax.plot(*vert_uv[vert_on_seam].t(), color='black', linewidth=10)
        ax.set_xlim(0,1)
        ax.set_ylim(0,1)
        seam_mask = torch.FloatTensor(fig_to_img(fig, dpi, im_size))
        plt.close(fig)

        # Dark pixels of the seam drawing become the mask; paint them red.
        seam_mask = (seam_mask[:,:,:1] < 0.1).float()
        red = torch.FloatTensor([1,0,0]).view(1,1,3)
        face_tex_map = seam_mask * red + (1-seam_mask) * face_tex_map
        aux['face_tex_map'] = face_tex_map.to(device)
        aux['seam_mask'] = seam_mask.to(device)

    return sym_sph_mesh.to(device), aux