kyleleey committed on
Commit
3bf37d0
1 Parent(s): d2d9973
Files changed (3)
  1. .gitignore +4 -13
  2. app.py +619 -0
  3. requirements.txt +32 -0
.gitignore CHANGED
@@ -1,22 +1,13 @@
1
  __pycache__
2
- data
3
- data/*/
4
- data/*/*
5
- !data/preprocessing/
6
  pretrained/*/
7
  results
8
  neural_renderer
9
  *.zip
10
  unchanged/
11
- cvpr23_results/
12
- # slurm.bash
13
- results
14
- results/*/
15
- results/*
16
- results/*/*
17
- results/dor_checkpoints/*
18
- results/dor_checkpoints/*/*
19
- results/dor_checkpoints/*/*/*
20
 
21
 
22
  .vscode
 
1
  __pycache__
2
+ # data
3
+ # data/*/
4
+ # data/*/*
5
+
6
  pretrained/*/
7
  results
8
  neural_renderer
9
  *.zip
10
  unchanged/
11
 
12
 
13
  .vscode
app.py ADDED
@@ -0,0 +1,619 @@
1
+ import os
2
+ import fire
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from functools import partial
6
+ import argparse
7
+
8
+ import cv2
9
+ import time
10
+ import numpy as np
11
+ import trimesh
12
+ from segment_anything import sam_model_registry, SamPredictor
13
+
14
+ import random
15
+ from pytorch3d import transforms
16
+ import torch
17
+ import torchvision
18
+ import torch.distributed as dist
19
+ import nvdiffrast.torch as dr
20
+ from video3d.model_ddp import Unsup3DDDP, forward_to_matrix
21
+ from video3d.trainer_few_shot import Fewshot_Trainer
22
+ from video3d.trainer_ddp import TrainerDDP
23
+ from video3d import setup_runtime
24
+ from video3d.render.mesh import make_mesh
25
+ from video3d.utils.skinning_v4 import estimate_bones, skinning, euler_angles_to_matrix
26
+ from video3d.utils.misc import save_obj
27
+ from video3d.render import util
28
+ import matplotlib.pyplot as plt
29
+ from pytorch3d import utils, renderer, transforms, structures, io
30
+ from video3d.render.render import render_mesh
31
+ from video3d.render.material import texture as material_texture
32
+
33
+
34
+ _TITLE = '''Learning the 3D Fauna of the Web'''
35
+ _DESCRIPTION = '''
36
+ <div>
37
+ Reconstruct any quadruped animal from one image.
38
+ </div>
39
+ <div>
40
+ The demo only contains the 3D reconstruction part.
41
+ </div>
42
+ '''
43
+ _GPU_ID = 0
44
+
45
+ if not hasattr(Image, 'Resampling'):
46
+ Image.Resampling = Image
47
+
48
+
49
+ def sam_init():
50
+ sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
51
+ model_type = "vit_h"
52
+
53
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
54
+ predictor = SamPredictor(sam)
55
+ return predictor
56
+
57
+
58
+ def sam_segment(predictor, input_image, *bbox_coords):
59
+ bbox = np.array(bbox_coords)
60
+ image = np.asarray(input_image)
61
+
62
+ start_time = time.time()
63
+ predictor.set_image(image)
64
+
65
+ masks_bbox, scores_bbox, logits_bbox = predictor.predict(
66
+ box=bbox,
67
+ multimask_output=True
68
+ )
69
+
70
+ print(f"SAM Time: {time.time() - start_time:.3f}s")
71
+ out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
72
+ out_image[:, :, :3] = image
73
+ out_image_bbox = out_image.copy()
74
+ out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
75
+ torch.cuda.empty_cache()
76
+ # out_image_bbox has four channels (RGB + SAM mask in the alpha channel), so save it as RGBA
+ return Image.fromarray(out_image_bbox, mode='RGBA')
77
+ # return Image.fromarray(out_image_bbox, mode='RGB')
78
+
79
+
80
+ def expand2square(pil_img, background_color):
81
+ width, height = pil_img.size
82
+ if width == height:
83
+ return pil_img
84
+ elif width > height:
85
+ result = Image.new(pil_img.mode, (width, width), background_color)
86
+ result.paste(pil_img, (0, (width - height) // 2))
87
+ return result
88
+ else:
89
+ result = Image.new(pil_img.mode, (height, height), background_color)
90
+ result.paste(pil_img, ((height - width) // 2, 0))
91
+ return result
92
+
93
+
94
+ def preprocess(predictor, input_image, chk_group=None, segment=True):
95
+ RES = 1024
96
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
97
+ if chk_group is not None:
98
+ segment = "Use SAM to center animal" in chk_group
99
+ if segment:
100
+ image_rem = input_image.convert('RGB')
101
+ arr = np.asarray(image_rem)[:,:,-1]
102
+ x_nonzero = np.nonzero(arr.sum(axis=0))
103
+ y_nonzero = np.nonzero(arr.sum(axis=1))
104
+ x_min = int(x_nonzero[0].min())
105
+ y_min = int(y_nonzero[0].min())
106
+ x_max = int(x_nonzero[0].max())
107
+ y_max = int(y_nonzero[0].max())
108
+ input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
109
+ # Rescale and recenter
110
+ # if rescale:
111
+ # image_arr = np.array(input_image)
112
+ # in_w, in_h = image_arr.shape[:2]
113
+ # out_res = min(RES, max(in_w, in_h))
114
+ # ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
115
+ # x, y, w, h = cv2.boundingRect(mask)
116
+ # max_size = max(w, h)
117
+ # ratio = 0.75
118
+ # side_len = int(max_size / ratio)
119
+ # padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
120
+ # center = side_len//2
121
+ # padded_image[center-h//2:center-h//2+h, center-w//2:center-w//2+w] = image_arr[y:y+h, x:x+w]
122
+ # rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS)
123
+
124
+ # rgba_arr = np.array(rgba) / 255.0
125
+ # rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:])
126
+ # input_image = Image.fromarray((rgb * 255).astype(np.uint8))
127
+ # else:
128
+ # input_image = expand2square(input_image, (127, 127, 127, 0))
129
+
130
+ input_image = expand2square(input_image, (0, 0, 0))
131
+ return input_image, input_image.resize((256, 256), Image.Resampling.LANCZOS)
132
+
133
+
134
+ def save_images(images, mask_pred, mode="transparent"):
135
+ img = images[0]
136
+ mask = mask_pred[0]
137
+ img = img.clamp(0, 1)
138
+ if mask is not None:
139
+ mask = mask.clamp(0, 1)
140
+ if mode == "white":
141
+ img = img * mask + 1 * (1 - mask)
142
+ elif mode == "black":
143
+ img = img * mask + 0 * (1 - mask)
144
+ else:
145
+ img = torch.cat([img, mask[0:1]], 0)
146
+
147
+ img = img.permute(1, 2, 0).cpu().numpy()
148
+ img = Image.fromarray(np.uint8(img * 255))
149
+ return img
150
+
151
+
152
+ def get_bank_embedding(rgb, memory_bank_keys, memory_bank, model, memory_bank_topk=10, memory_bank_dim=128):
153
+ images = rgb
154
+ batch_size, num_frames, _, h0, w0 = images.shape
155
+ images = images.reshape(batch_size*num_frames, *images.shape[2:]) # 0~1
156
+ images_in = images * 2 - 1 # rescale to (-1, 1) for DINO
157
+
158
+ x = images_in
159
+ with torch.no_grad():
160
+ b, c, h, w = x.shape
161
+ model.netInstance.netEncoder._feats = []
162
+ model.netInstance.netEncoder._register_hooks([11], 'key')
163
+ #self._register_hooks([11], 'token')
164
+ x = model.netInstance.netEncoder.ViT.prepare_tokens(x)
165
+ #x = self.ViT.prepare_tokens_with_masks(x)
166
+
167
+ for blk in model.netInstance.netEncoder.ViT.blocks:
168
+ x = blk(x)
169
+ out = model.netInstance.netEncoder.ViT.norm(x)
170
+ model.netInstance.netEncoder._unregister_hooks()
171
+
172
+ ph, pw = h // model.netInstance.netEncoder.patch_size, w // model.netInstance.netEncoder.patch_size
173
+ patch_out = out[:, 1:] # first is class token
174
+ patch_out = patch_out.reshape(b, ph, pw, model.netInstance.netEncoder.vit_feat_dim).permute(0, 3, 1, 2)
175
+
176
+ patch_key = model.netInstance.netEncoder._feats[0][:,:,1:] # B, num_heads, num_patches, dim
177
+ patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, model.netInstance.netEncoder.vit_feat_dim, ph, pw)
178
+
179
+ global_feat = out[:, 0]
180
+
181
+ batch_features = global_feat
182
+
183
+ batch_size = batch_features.shape[0]
184
+
185
+ query = torch.nn.functional.normalize(batch_features.unsqueeze(1), dim=-1) # [B, 1, d_k]
186
+ key = torch.nn.functional.normalize(memory_bank_keys, dim=-1) # [size, d_k]
187
+ key = key.transpose(1, 0).unsqueeze(0).repeat(batch_size, 1, 1).to(query.device) # [B, d_k, size]
188
+
189
+ cos_dist = torch.bmm(query, key).squeeze(1) # [B, size], larger the more similar
190
+ rank_idx = torch.sort(cos_dist, dim=-1, descending=True)[1][:, :memory_bank_topk] # [B, k]
191
+ value = memory_bank.unsqueeze(0).repeat(batch_size, 1, 1).to(query.device) # [B, size, d_v]
192
+
193
+ out = torch.gather(value, dim=1, index=rank_idx[..., None].repeat(1, 1, memory_bank_dim)) # [B, k, d_v]
194
+
195
+ weights = torch.gather(cos_dist, dim=-1, index=rank_idx) # [B, k]
196
+ weights = torch.nn.functional.normalize(weights, p=1.0, dim=-1).unsqueeze(-1).repeat(1, 1, memory_bank_dim) # [B, k, d_v] weights have been normalized
197
+
198
+ out = weights * out
199
+ out = torch.sum(out, dim=1)
200
+
201
+ batch_mean_out = torch.mean(out, dim=0)
202
+
203
+ weight_aux = {
204
+ 'weights': weights[:, :, 0], # [B, k], weights from large to small
205
+ 'pick_idx': rank_idx, # [B, k]
206
+ }
207
+
208
+ batch_embedding = batch_mean_out
209
+ embeddings = out
210
+ weights = weight_aux
211
+
212
+ bank_embedding_model_input = [batch_embedding, embeddings, weights]
213
+
214
+ return bank_embedding_model_input
215
+
216
+
217
+ class FixedDirectionLight(torch.nn.Module):
218
+ def __init__(self, direction, amb, diff):
219
+ super(FixedDirectionLight, self).__init__()
220
+ self.light_dir = direction
221
+ self.amb = amb
222
+ self.diff = diff
223
+ self.is_hacking = not (isinstance(self.amb, float)
224
+ or isinstance(self.amb, int))
225
+
226
+ def forward(self, feat):
227
+ batch_size = feat.shape[0]
228
+ if self.is_hacking:
229
+ return torch.concat([self.light_dir, self.amb, self.diff], -1)
230
+ else:
231
+ return torch.concat([self.light_dir, torch.FloatTensor([self.amb, self.diff]).to(self.light_dir.device)], -1).expand(batch_size, -1)
232
+
233
+ def shade(self, feat, kd, normal):
234
+ light_params = self.forward(feat)
235
+ light_dir = light_params[..., :3][:, None, None, :]
236
+ int_amb = light_params[..., 3:4][:, None, None, :]
237
+ int_diff = light_params[..., 4:5][:, None, None, :]
238
+ shading = (int_amb + int_diff *
239
+ torch.clamp(util.dot(light_dir, normal), min=0.0))
240
+ shaded = shading * kd
241
+ return shaded, shading
242
+
243
+
244
+ def render_bones(mvp, bones_pred, size=(256, 256)):
245
+ bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1)
246
+ b, f, num_bones = bone_world4.shape[:3]
247
+ bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4)
248
+ bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2
249
+ dpi = 32
250
+ fx, fy = size[1] // dpi, size[0] // dpi
251
+
252
+ rendered = []
253
+ for b_idx in range(b):
254
+ for f_idx in range(f):
255
+ frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy()
256
+ fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False)
257
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
258
+ ax.set_axis_off()
259
+ for bone in frame_bones_uv:
260
+ ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20)
261
+ ax.set_xlim(-1, 1)
262
+ ax.set_ylim(-1, 1)
263
+ ax.invert_yaxis()
264
+ # Convert to image
265
+ fig.add_axes(ax)
266
+ fig.canvas.draw_idle()
267
+ image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
268
+ w, h = fig.canvas.get_width_height()
269
+ image = image.reshape(h, w, 3)  # frombuffer arrays are read-only, so reshape instead of in-place resize
270
+ rendered += [image / 255.]
271
+ return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)).to(bones_pred.device)
272
+
273
+ def add_mesh_color(mesh, color):
274
+ verts = mesh.verts_padded()
275
+ color = torch.FloatTensor(color).to(verts.device).view(1,1,3) / 255
276
+ mesh.textures = renderer.TexturesVertex(verts_features=verts*0+color)
277
+ return mesh
278
+
279
+ def create_sphere(position, scale, device, color=[139, 149, 173]):
280
+ mesh = utils.ico_sphere(2).to(device)
281
+ mesh = mesh.extend(position.shape[0])
282
+
283
+ # scale and offset
284
+ mesh = mesh.update_padded(mesh.verts_padded() * scale + position[:, None])
285
+
286
+ mesh = add_mesh_color(mesh, color)
287
+
288
+ return mesh
289
+
290
+ def estimate_bone_rotation(b):
291
+ """
292
+ (0, 0, 1) = matmul(R^(-1), b)
293
+
294
+ assumes x, y is a symmetry plane
295
+
296
+ returns R
297
+ """
298
+ b = b / torch.norm(b, dim=-1, keepdim=True)
299
+
300
+ n = torch.FloatTensor([[1, 0, 0]]).to(b.device)
301
+ n = n.expand_as(b)
302
+ v = torch.cross(b, n, dim=-1)
303
+
304
+ R = torch.stack([n, v, b], dim=-1).transpose(-2, -1)
305
+
306
+ return R
307
+
308
+ def estimate_vector_rotation(vector_a, vector_b):
309
+ """
310
+ vector_a = matmul(R, vector_b)
311
+
312
+ returns R
313
+
314
+ https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
315
+ """
316
+ vector_a = vector_a / torch.norm(vector_a, dim=-1, keepdim=True)
317
+ vector_b = vector_b / torch.norm(vector_b, dim=-1, keepdim=True)
318
+
319
+ v = torch.cross(vector_a, vector_b, dim=-1)
320
+ c = torch.sum(vector_a * vector_b, dim=-1)
321
+
322
+ skew = torch.stack([
323
+ torch.stack([torch.zeros_like(v[..., 0]), -v[..., 2], v[..., 1]], dim=-1),
324
+ torch.stack([v[..., 2], torch.zeros_like(v[..., 0]), -v[..., 0]], dim=-1),
325
+ torch.stack([-v[..., 1], v[..., 0], torch.zeros_like(v[..., 0])], dim=-1)],
326
+ dim=-1)
327
+
328
+ R = torch.eye(3, device=vector_a.device)[None] + skew + torch.matmul(skew, skew) / (1 + c[..., None, None])
329
+
330
+ return R
331
+
332
+ def create_elipsoid(bone, scale=0.05, color=[139, 149, 173], generic_rotation_estim=True):
333
+ length = torch.norm(bone[:, 0] - bone[:, 1], dim=-1)
334
+
335
+ mesh = utils.ico_sphere(2).to(bone.device)
336
+ mesh = mesh.extend(bone.shape[0])
337
+ # scale x, y
338
+ verts = mesh.verts_padded() * torch.FloatTensor([scale, scale, 1]).to(bone.device)
339
+ # stretch along z axis, set the start to origin
340
+ verts[:, :, 2] = verts[:, :, 2] * length[:, None] * 0.5 + length[:, None] * 0.5
341
+
342
+ bone_vector = bone[:, 1] - bone[:, 0]
343
+ z_vector = torch.FloatTensor([[0, 0, 1]]).to(bone.device)
344
+ z_vector = z_vector.expand_as(bone_vector)
345
+ if generic_rotation_estim:
346
+ rot = estimate_vector_rotation(z_vector, bone_vector)
347
+ else:
348
+ rot = estimate_bone_rotation(bone_vector)
349
+ tsf = transforms.Rotate(rot, device=bone.device)
350
+ tsf = tsf.compose(transforms.Translate(bone[:, 0], device=bone.device))
351
+ verts = tsf.transform_points(verts)
352
+
353
+ mesh = mesh.update_padded(verts)
354
+
355
+ mesh = add_mesh_color(mesh, color)
356
+
357
+ return mesh
358
+
359
+ def convert_textures_vertex_to_textures_uv(meshes: structures.Meshes, color1, color2) -> renderer.TexturesUV:
360
+ """
361
+ Convert a TexturesVertex object to a TexturesUV object.
362
+ """
363
+ color1 = torch.Tensor(color1).to(meshes.device).view(1, 1, 3) / 255
364
+ color2 = torch.Tensor(color2).to(meshes.device).view(1, 1, 3) / 255
365
+ textures_vertex = meshes.textures
366
+ assert isinstance(textures_vertex, renderer.TexturesVertex), "Input meshes must have TexturesVertex"
367
+ verts_rgb = textures_vertex.verts_features_padded()
368
+ faces_uvs = meshes.faces_padded()
369
+ batch_size = verts_rgb.shape[0]
370
+ maps = torch.zeros(batch_size, 128, 128, 3, device=verts_rgb.device)
371
+ maps[:, :, :64, :] = color1
372
+ maps[:, :, 64:, :] = color2
373
+ is_first = (verts_rgb == color1)[..., 0]
374
+ verts_uvs = torch.zeros(batch_size, verts_rgb.shape[1], 2, device=verts_rgb.device)
375
+ verts_uvs[is_first] = torch.FloatTensor([0.25, 0.5]).to(verts_rgb.device)
376
+ verts_uvs[~is_first] = torch.FloatTensor([0.75, 0.5]).to(verts_rgb.device)
377
+ textures_uv = renderer.TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
378
+ meshes.textures = textures_uv
379
+ return meshes
380
+
381
+ def create_bones_scene(bones, joint_color=[66, 91, 140], bone_color=[119, 144, 189], show_end_point=False):
382
+ meshes = []
383
+ for bone_i in range(bones.shape[1]):
384
+ # points
385
+ meshes += [create_sphere(bones[:, bone_i, 0], 0.1, bones.device, color=joint_color)]
386
+ if show_end_point:
387
+ meshes += [create_sphere(bones[:, bone_i, 1], 0.1, bones.device, color=joint_color)]
388
+
389
+ # connecting ellipsoid
390
+ meshes += [create_elipsoid(bones[:, bone_i], color=bone_color)]
391
+
392
+ current_batch_size = bones.shape[0]
393
+ meshes = [structures.join_meshes_as_scene([m[i] for m in meshes]) for i in range(current_batch_size)]
394
+ mesh = structures.join_meshes_as_batch(meshes)
395
+
396
+ return mesh
397
+
398
+
399
+ def run_pipeline(model_items, cfgs, input_img, device):
400
+ epoch = 999
401
+ total_iter = 999999
402
+ model = model_items[0]
403
+ memory_bank = model_items[1]
404
+ memory_bank_keys = model_items[2]
405
+
406
+ input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)
407
+
408
+ with torch.no_grad():
409
+ model.netPrior.eval()
410
+ model.netInstance.eval()
411
+ input_image = torch.nn.functional.interpolate(input_image, size=(256, 256), mode='bilinear', align_corners=False)
412
+ input_image = input_image[:, None, :, :] # [B=1, F=1, 3, 256, 256]
413
+
414
+ bank_embedding = get_bank_embedding(
415
+ input_image,
416
+ memory_bank_keys,
417
+ memory_bank,
418
+ model,
419
+ memory_bank_topk=cfgs.get("memory_bank_topk", 10),
420
+ memory_bank_dim=128
421
+ )
422
+
423
+ prior_shape, dino_pred, classes_vectors = model.netPrior(
424
+ category_name='tmp',
425
+ perturb_sdf=False,
426
+ total_iter=total_iter,
427
+ is_training=False,
428
+ class_embedding=bank_embedding
429
+ )
430
+ Instance_out = model.netInstance(
431
+ 'tmp',
432
+ input_image,
433
+ prior_shape,
434
+ epoch,
435
+ dino_features=None,
436
+ dino_clusters=None,
437
+ total_iter=total_iter,
438
+ is_training=False
439
+ ) # frame dim collapsed N=(B*F)
440
+ if len(Instance_out) == 13:
441
+ shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features, dino_feat_im_calc, deform, all_arti_params, light, forward_aux = Instance_out
442
+ im_features_map = None
443
+ else:
444
+ shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features, dino_feat_im_calc, deform, all_arti_params, light, forward_aux, im_features_map = Instance_out
445
+
446
+ class_vector = classes_vectors # the bank embeddings
447
+
448
+ gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)
449
+
450
+ image_pred, mask_pred, _, _, _, shading = model.render(
451
+ shape, texture_pred, mvp, w2c, campos, 256, background=model.background_mode,
452
+ im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
453
+ render_flow=False, dino_pred=None, im_features_map=im_features_map
454
+ )
455
+ mask_pred = mask_pred.expand_as(image_pred)
456
+ shading = shading.expand_as(image_pred)
457
+ # render bones in pytorch3D style
458
+ posed_bones = forward_aux["posed_bones"].squeeze(1)
459
+ jc, bc = [66, 91, 140], [119, 144, 189]
460
+ bones_meshes = create_bones_scene(posed_bones, joint_color=jc, bone_color=bc, show_end_point=True)
461
+ bones_meshes = convert_textures_vertex_to_textures_uv(bones_meshes, color1=jc, color2=bc)
462
+ nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
463
+ uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
464
+ material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
465
+ buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=256, bsdf="diffuse")
466
+
467
+ shaded = buffers["shaded"].permute(0, 3, 1, 2)
468
+ bone_image = shaded[:, :3, :, :]
469
+ bone_mask = shaded[:, 3:, :, :]
470
+ mask_final = mask_pred.logical_or(bone_mask)
471
+ mask_final = mask_final.int()
472
+ image_with_bones = bone_image * bone_mask * 0.5 + (shading * (1 - bone_mask * 0.5) + 0.5 * (mask_final.float() - mask_pred.float()))
473
+
474
+ mesh_image = save_images(shading, mask_pred)
475
+ mesh_bones_image = save_images(image_with_bones, mask_final)
476
+
477
+ final_shape = shape.clone()
478
+ prior_shape = prior_shape.clone()
479
+
480
+ final_mesh_tri = trimesh.Trimesh(
481
+ vertices=final_shape.v_pos[0].detach().cpu().numpy(),
482
+ faces=final_shape.t_pos_idx[0].detach().cpu().numpy(),
483
+ process=False,
484
+ maintain_order=True)
485
+ prior_mesh_tri = trimesh.Trimesh(
486
+ vertices=prior_shape.v_pos[0].detach().cpu().numpy(),
487
+ faces=prior_shape.t_pos_idx[0].detach().cpu().numpy(),
488
+ process=False,
489
+ maintain_order=True)
490
+
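+ # Assumed completion (not in the original commit): the Gradio callback below expects four
+ # outputs (two rendered views and two mesh files); a minimal sketch exports the meshes to
+ # illustrative paths and returns them.
+ final_mesh_path = os.path.join(os.path.dirname(__file__), "demo_reconstructed.obj")
+ prior_mesh_path = os.path.join(os.path.dirname(__file__), "demo_base_shape.obj")
+ final_mesh_tri.export(final_mesh_path)
+ prior_mesh_tri.export(prior_mesh_path)
+ return mesh_image, mesh_bones_image, final_mesh_path, prior_mesh_path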
491
+
492
+
493
+ def run_demo():
494
+ parser = argparse.ArgumentParser()
495
+ parser.add_argument('--gpu', default='0', type=str,
496
+ help='Specify a GPU device')
497
+ parser.add_argument('--num_workers', default=4, type=int,
498
+ help='Specify the number of worker threads for data loaders')
499
+ parser.add_argument('--seed', default=0, type=int,
500
+ help='Specify a random seed')
501
+ parser.add_argument('--config', default='./ckpts/configs.yml',
502
+ type=str) # Model config path
503
+ parser.add_argument('--checkpoint_path', default='./ckpts/iter0800000.pth', type=str)
504
+
505
+ args = parser.parse_args()
506
+
507
+ torch.manual_seed(args.seed)
508
+ os.environ['MASTER_ADDR'] = 'localhost'
509
+ os.environ['MASTER_PORT'] = '8088'
510
+ dist.init_process_group("gloo", rank=_GPU_ID, world_size=1)
511
+ torch.cuda.set_device(_GPU_ID)
512
+ args.rank = _GPU_ID
513
+ args.world_size = 1
514
+ args.gpu = os.environ.get('CUDA_VISIBLE_DEVICES', args.gpu)  # avoid KeyError when the variable is unset
515
+ device = f'cuda:{_GPU_ID}'
516
+
517
+ resolution = (256, 256)
518
+ batch_size = 1
519
+ model_cfgs = setup_runtime(args)
520
+ bone_y_thresh = 0.4
521
+ body_bone_idx_preset = [3, 6, 6, 3]
522
+ model_cfgs['body_bone_idx_preset'] = body_bone_idx_preset
523
+
524
+ model = Unsup3DDDP(model_cfgs)
525
+ # a hack attempt
526
+ model.netPrior.classes_vectors = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(123, 128), a=-0.05, b=0.05))
527
+ cp = torch.load(args.checkpoint_path, map_location=device)
528
+ model.load_model_state(cp)
529
+ memory_bank_keys = cp['memory_bank_keys']
530
+ memory_bank = cp['memory_bank']
531
+
532
+ model.to(device)
533
+ memory_bank = memory_bank.to(device)  # Tensor.to is not in-place; keep the moved tensors
534
+ memory_bank_keys = memory_bank_keys.to(device)
535
+ model_items = [
536
+ model,
537
+ memory_bank,
538
+ memory_bank_keys
539
+ ]
540
+
541
+ predictor = sam_init()
542
+
543
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
544
+ button_secondary_background_fill="*neutral_100",
545
+ button_secondary_background_fill_hover="*neutral_200")
546
+ custom_css = '''#disp_image {
547
+ text-align: center; /* Horizontally center the content */
548
+ }'''
549
+
550
+ with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
551
+ with gr.Row():
552
+ with gr.Column(scale=1):
553
+ gr.Markdown('# ' + _TITLE)
554
+ gr.Markdown(_DESCRIPTION)
555
+ with gr.Row(variant='panel'):
556
+ with gr.Column(scale=1):
557
+ input_image = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image')  # `tool` was removed in Gradio 4.x
558
+
559
+ example_folder = os.path.join(os.path.dirname(__file__), "./example_images")
560
+ example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
561
+ gr.Examples(
562
+ examples=example_fns,
563
+ inputs=[input_image],
564
+ # outputs=[input_image],
565
+ cache_examples=False,
566
+ label='Examples (click one of the images below to start)',
567
+ examples_per_page=30
568
+ )
569
+ with gr.Column(scale=1):
570
+ processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=256, image_mode='RGB', elem_id="disp_image")
571
+ processed_image_highres = gr.Image(type='pil', image_mode='RGB', visible=False)
572
+
573
+ with gr.Accordion('Advanced options', open=True):
574
+ with gr.Row():
575
+ with gr.Column():
576
+ input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
577
+ label='Input Image Preprocessing',
578
+ value=['Use SAM to center animal'],
579
+ info='Untick this if the animal is already centered, e.g. in the example images')
580
+ # with gr.Column():
581
+ # output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
582
+ # with gr.Row():
583
+ # with gr.Column():
584
+ # scale_slider = gr.Slider(1, 5, value=3, step=1,
585
+ # label='Classifier Free Guidance Scale')
586
+ # with gr.Column():
587
+ # steps_slider = gr.Slider(15, 100, value=50, step=1,
588
+ # label='Number of Diffusion Inference Steps')
589
+ # with gr.Row():
590
+ # with gr.Column():
591
+ # seed = gr.Number(42, label='Seed')
592
+ # with gr.Column():
593
+ # crop_size = gr.Number(192, label='Crop size')
594
+ # crop_size = 192
595
+ run_btn = gr.Button('Generate', variant='primary', interactive=True)
596
+ with gr.Row():
597
+ view_1 = gr.Image(interactive=False, height=256, show_label=False)
598
+ view_2 = gr.Image(interactive=False, height=256, show_label=False)
599
+ with gr.Row():
600
+ shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
601
+ shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")
602
+
603
+ with gr.Row():
604
+ view_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
605
+ normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
606
+
607
+
608
+ run_btn.click(fn=partial(preprocess, predictor),
609
+ inputs=[input_image, input_processing],
610
+ outputs=[processed_image_highres, processed_image], queue=True
611
+ ).success(fn=partial(run_pipeline, model_items, model_cfgs, device=device),
612
+ inputs=[processed_image],  # `device` is a plain string, not a component, so bind it via partial
613
+ outputs=[view_1, view_2, shape_1, shape_2]
614
+ )
615
+ demo.queue().launch(share=True, max_threads=80)
616
+
617
+
618
+ if __name__ == '__main__':
619
+ fire.Fire(run_demo)
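
The memory-bank lookup in get_bank_embedding above boils down to a top-k cosine-similarity retrieval over the bank keys followed by an L1-normalized weighted sum of the matching bank values. A minimal, self-contained PyTorch sketch of just that step (the function name, batch size, bank size and the 384/128 feature dimensions are illustrative assumptions, not values taken from the checkpoint):

import torch
import torch.nn.functional as F

def topk_bank_lookup(query_feat, bank_keys, bank_values, k=10):
    # query_feat: [B, d_k] global image features; bank_keys: [size, d_k]; bank_values: [size, d_v]
    q = F.normalize(query_feat, dim=-1)
    keys = F.normalize(bank_keys, dim=-1)
    cos = q @ keys.t()                                  # [B, size] cosine similarities
    weights, idx = cos.topk(k, dim=-1)                  # [B, k] most similar bank entries
    weights = F.normalize(weights, p=1.0, dim=-1)       # per-query weights sum to 1
    picked = bank_values[idx]                           # [B, k, d_v]
    return (weights.unsqueeze(-1) * picked).sum(dim=1)  # [B, d_v] blended bank embedding

# Example with made-up sizes: one query against a bank of 100 entries.
emb = topk_bank_lookup(torch.randn(1, 384), torch.randn(100, 384), torch.randn(100, 128))
print(emb.shape)  # torch.Size([1, 128])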
requirements.txt ADDED
@@ -0,0 +1,32 @@
1
+ clip==1.0
2
+ ConfigArgParse==1.5.3
3
+ core==1.0.1
4
+ diffusers==0.20.0
5
+ einops==0.4.1
6
+ faiss==1.7.3
7
+ fire==0.5.0
8
+ glfw==2.5.7
9
+ gradio==4.12.0
10
+ imageio==2.27.0
11
+ ipdb==0.13.9
12
+ lpips==0.1.4
13
+ matplotlib==3.8.1
14
+ numpy==1.23.1
15
+ nvdiffrast==0.3.0
16
+ # Pillow==9.2.0  (duplicate pin; Pillow==10.1.0 below is used)
17
+ Pillow==10.1.0
18
+ # PyOpenGL==3.1.6  (duplicate pin; PyOpenGL==3.1.7 below is used)
19
+ PyOpenGL==3.1.7
20
+ pytorch3d==0.7.2
21
+ # PyYAML==6.0  (duplicate pin; PyYAML==6.0.1 below is used)
22
+ PyYAML==6.0.1
23
+ scipy==1.9.1
24
+ segment_anything==1.0
25
+ siren_pytorch==0.1.7
26
+ tinycudann==1.7
27
+ torch==1.10.0
28
+ torchvision==0.11.0
29
+ transformers==4.28.1
30
+ trimesh==4.0.0
31
+ wandb==0.14.2
32
+ xatlas==0.0.7