wzhouxiff committed on
Commit
6ccbde2
1 Parent(s): 5f3a61f

update global variables traj_list and camera_dict

Browse files
Files changed (2) hide show
  1. app.py +8 -108
  2. gradio_utils/page_control.py +111 -12
app.py CHANGED
@@ -13,9 +13,8 @@ from omegaconf import OmegaConf
13
  from PIL import Image
14
  from pytorch_lightning import seed_everything
15
 
16
- from gradio_utils.camera_utils import CAMERA_MOTION_MODE, process_camera, create_relative
17
- from gradio_utils.traj_utils import (OBJECT_MOTION_MODE, get_provided_traj,
18
- process_points, process_traj)
19
  from gradio_utils.utils import vis_camera
20
  from lvdm.models.samplers.ddim import DDIMSampler
21
  from main.evaluation.motionctrl_inference import (DEFAULT_NEGATIVE_PROMPT,
@@ -23,7 +22,8 @@ from main.evaluation.motionctrl_inference import (DEFAULT_NEGATIVE_PROMPT,
23
  post_prompt)
24
  from utils.utils import instantiate_from_config
25
 
26
- from gradio_utils.page_control import (MODE, BASE_MODEL, traj_list, camera_dict,
 
27
  reset_camera,
28
  visualized_step1, visualized_step2,
29
  visualized_camera_poses, visualized_traj_poses,
@@ -31,7 +31,8 @@ from gradio_utils.page_control import (MODE, BASE_MODEL, traj_list, camera_dict,
31
  input_raw_camera_pose,
32
  change_camera_mode, change_camera_speed,
33
  add_traj_point, add_provided_traj,
34
- fn_traj_droplast, fn_traj_reset)
 
35
 
36
  os.environ['KMP_DUPLICATE_LIB_OK']='True'
37
  SPACE_ID = os.environ.get('SPACE_ID', '')
@@ -140,107 +141,6 @@ for i in range(0, 16):
140
  res.append(RT)
141
 
142
  fig = vis_camera(res)
143
-
144
- def fn_vis_camera(info_mode, camera_args=None):
145
- global camera_dict
146
- RT = process_camera(camera_dict, camera_args) # [t, 3, 4]
147
-
148
- rescale_T = 1.0
149
- rescale_T = max(rescale_T, np.max(np.abs(RT[:,:,-1])) / 1.9)
150
-
151
- fig = vis_camera(create_relative(RT), rescale_T=rescale_T)
152
-
153
- if info_mode == MODE[0]:
154
- vis_step3_prompt_generate = True
155
- vis_prompt = True
156
- vis_num_samples = True
157
- vis_seed = True
158
- vis_start = True
159
- vis_gen_video = True
160
-
161
- vis_object_mode = False
162
- vis_object_info = False
163
-
164
- else:
165
- vis_step3_prompt_generate = False
166
- vis_prompt = False
167
- vis_num_samples = False
168
- vis_seed = False
169
- vis_start = False
170
- vis_gen_video = False
171
-
172
- vis_object_mode = True
173
- vis_object_info = True
174
-
175
- return fig, \
176
- gr.update(visible=vis_object_mode), \
177
- gr.update(visible=vis_object_info), \
178
- gr.update(visible=vis_step3_prompt_generate), \
179
- gr.update(visible=vis_prompt), \
180
- gr.update(visible=vis_num_samples), \
181
- gr.update(visible=vis_seed), \
182
- gr.update(visible=vis_start), \
183
- gr.update(visible=vis_gen_video, value=None)
184
-
185
- def fn_vis_traj():
186
- global traj_list
187
- global exp_no
188
- xy_range = 1024
189
- points = process_points(traj_list)
190
- imgs = []
191
- for idx in range(16):
192
- bg_img = np.ones((1024, 1024, 3), dtype=np.uint8) * 255
193
- for i in range(15):
194
- p = points[i]
195
- p1 = points[i+1]
196
- cv2.line(bg_img, p, p1, (255, 0, 0), 2)
197
-
198
- if i == idx:
199
- cv2.circle(bg_img, p, 2, (0, 255, 0), 20)
200
-
201
- if idx==(15):
202
- cv2.circle(bg_img, points[-1], 2, (0, 255, 0), 20)
203
-
204
- imgs.append(bg_img.astype(np.uint8))
205
-
206
- # size = (512, 512)
207
- fps = 10
208
-
209
- out_dir = f'./results_trajs/{exp_no}'
210
- os.makedirs(out_dir, exist_ok=True)
211
- exp_no += 1
212
-
213
- traj_flow = process_traj(traj_list).transpose(3,0,1,2)
214
-
215
- np.save(f'{out_dir}/traj_flow.npy', traj_flow)
216
- with open(f'{out_dir}/traj_list.txt', 'w') as f:
217
- for item in traj_list:
218
- f.write(f"{item[0]}, {item[1]}\n")
219
-
220
- if out_dir is None:
221
- path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
222
- else:
223
- path = os.path.join(out_dir, 'traj.mp4')
224
- writer = imageio.get_writer(path, format='mp4', mode='I', fps=fps)
225
- for img in imgs:
226
- writer.append_data(img)
227
-
228
- writer.close()
229
-
230
- vis_step3_prompt_generate = True
231
- vis_prompt = True
232
- vis_num_samples = True
233
- vis_seed = True
234
- vis_start = True
235
- vis_gen_video = True
236
- return path, gr.update(visible=vis_step3_prompt_generate), \
237
- gr.update(visible=vis_prompt), \
238
- gr.update(visible=vis_num_samples), \
239
- gr.update(visible=vis_seed), \
240
- gr.update(visible=vis_start), \
241
- gr.update(visible=vis_gen_video, value=None)
242
-
243
-
244
 
245
 
246
  ###########################################
@@ -274,8 +174,8 @@ if torch.cuda.is_available():
274
  model_v2.eval()
275
 
276
  def model_run(prompts, choose_model, infer_mode, seed, n_samples, camera_args=None):
277
- global traj_list
278
- global camera_dict
279
 
280
  RT = process_camera(camera_dict, camera_args).reshape(-1,12)
281
  traj_flow = process_traj(traj_list).transpose(3,0,1,2)
 
13
  from PIL import Image
14
  from pytorch_lightning import seed_everything
15
 
16
+ from gradio_utils.camera_utils import CAMERA_MOTION_MODE, process_camera
17
+ from gradio_utils.traj_utils import (OBJECT_MOTION_MODE, process_traj)
 
18
  from gradio_utils.utils import vis_camera
19
  from lvdm.models.samplers.ddim import DDIMSampler
20
  from main.evaluation.motionctrl_inference import (DEFAULT_NEGATIVE_PROMPT,
 
22
  post_prompt)
23
  from utils.utils import instantiate_from_config
24
 
25
+ from gradio_utils.page_control import (MODE, BASE_MODEL,
26
+ get_camera_dict, get_traj_list,
27
  reset_camera,
28
  visualized_step1, visualized_step2,
29
  visualized_camera_poses, visualized_traj_poses,
 
31
  input_raw_camera_pose,
32
  change_camera_mode, change_camera_speed,
33
  add_traj_point, add_provided_traj,
34
+ fn_traj_droplast, fn_traj_reset,
35
+ fn_vis_camera, fn_vis_traj,)
36
 
37
  os.environ['KMP_DUPLICATE_LIB_OK']='True'
38
  SPACE_ID = os.environ.get('SPACE_ID', '')
 
141
  res.append(RT)
142
 
143
  fig = vis_camera(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
 
146
  ###########################################
 
174
  model_v2.eval()
175
 
176
  def model_run(prompts, choose_model, infer_mode, seed, n_samples, camera_args=None):
177
+ traj_list = get_traj_list()
178
+ camera_dict = get_camera_dict()
179
 
180
  RT = process_camera(camera_dict, camera_args).reshape(-1,12)
181
  traj_flow = process_traj(traj_list).transpose(3,0,1,2)
gradio_utils/page_control.py CHANGED
@@ -1,6 +1,11 @@
 
 
 
 
1
  import gradio as gr
2
- from gradio_utils.camera_utils import CAMERA_MOTION_MODE
3
- from gradio_utils.traj_utils import get_provided_traj
 
4
 
5
  MODE = ["control camera poses", "control object trajectory", "control both camera and object motion"]
6
 
@@ -17,6 +22,8 @@ def display_camera_info(camera_dict, camera_mode=None):
17
  res += f"mode : {camera_dict['mode']}. "
18
  return res
19
 
 
 
20
  traj_list = []
21
  camera_dict = {
22
  "motion":[],
@@ -25,8 +32,16 @@ camera_dict = {
25
  "complex": None
26
  }
27
 
 
 
 
 
 
 
 
 
28
  def reset_camera():
29
- # global camera_dict
30
  camera_dict = {
31
  "motion":[],
32
  "mode": "Customized Mode 1: First A then B",
@@ -36,7 +51,7 @@ def reset_camera():
36
  return display_camera_info(camera_dict)
37
 
38
  def fn_traj_reset():
39
- # global traj_list
40
  traj_list = []
41
  return "Click to specify trajectory"
42
 
@@ -478,7 +493,7 @@ def visualized_traj_poses(step2_object_motion):
478
  gr.update(visible=vis_gen_video)
479
 
480
  def add_camera_motion(camera_motion, camera_mode):
481
- # global camera_dict
482
  if camera_dict['complex'] is not None:
483
  camera_dict['complex'] = None
484
  if camera_mode == CAMERA_MOTION_MODE[2] and len(camera_dict['motion']) <2:
@@ -489,7 +504,7 @@ def add_camera_motion(camera_motion, camera_mode):
489
  return display_camera_info(camera_dict, camera_mode)
490
 
491
  def add_complex_camera_motion(camera_motion):
492
- # global camera_dict
493
  camera_dict['complex']=camera_motion
494
  return display_camera_info(camera_dict)
495
 
@@ -522,7 +537,7 @@ def change_camera_mode(combine_type, camera_mode):
522
  gr.update(visible=vis_combine3_des)
523
 
524
  def input_raw_camera_pose(combine_type, camera_mode):
525
- # global camera_dict
526
  camera_dict['mode'] = combine_type
527
 
528
  vis_U = False
@@ -549,26 +564,26 @@ def input_raw_camera_pose(combine_type, camera_mode):
549
  gr.update(visible=vis_combine3_des)
550
 
551
  def change_camera_speed(camera_speed):
552
- # global camera_dict
553
  camera_dict['speed'] = camera_speed
554
  return display_camera_info(camera_dict)
555
 
556
  def add_traj_point(evt: gr.SelectData, ):
557
- # global traj_list
558
  traj_list.append(evt.index)
559
  traj_str = [f"{traj}" for traj in traj_list]
560
  return ", ".join(traj_str)
561
 
562
  def add_provided_traj(traj_name):
563
- # global traj_list
 
564
  traj_list = get_provided_traj(traj_name)
565
  traj_str = [f"{traj}" for traj in traj_list]
566
  return ", ".join(traj_str)
567
 
568
 
569
  def fn_traj_droplast():
570
- # global traj_list
571
-
572
  if traj_list:
573
  traj_list.pop()
574
 
@@ -578,3 +593,87 @@ def fn_traj_droplast():
578
  else:
579
  return "Click to specify trajectory"
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import tempfile
4
+ import imageio
5
  import gradio as gr
6
+ from gradio_utils.camera_utils import CAMERA_MOTION_MODE, process_camera, create_relative
7
+ from gradio_utils.traj_utils import get_provided_traj, process_points
8
+ from gradio_utils.utils import vis_camera
9
 
10
  MODE = ["control camera poses", "control object trajectory", "control both camera and object motion"]
11
 
 
22
  res += f"mode : {camera_dict['mode']}. "
23
  return res
24
 
25
+ global traj_list, camera_dict
26
+
27
  traj_list = []
28
  camera_dict = {
29
  "motion":[],
 
32
  "complex": None
33
  }
34
 
35
+ def get_traj_list():
36
+ global traj_list
37
+ return traj_list
38
+
39
+ def get_camera_dict():
40
+ global camera_dict
41
+ return camera_dict
42
+
43
  def reset_camera():
44
+ global camera_dict
45
  camera_dict = {
46
  "motion":[],
47
  "mode": "Customized Mode 1: First A then B",
 
51
  return display_camera_info(camera_dict)
52
 
53
  def fn_traj_reset():
54
+ global traj_list
55
  traj_list = []
56
  return "Click to specify trajectory"
57
 
 
493
  gr.update(visible=vis_gen_video)
494
 
495
  def add_camera_motion(camera_motion, camera_mode):
496
+ global camera_dict
497
  if camera_dict['complex'] is not None:
498
  camera_dict['complex'] = None
499
  if camera_mode == CAMERA_MOTION_MODE[2] and len(camera_dict['motion']) <2:
 
504
  return display_camera_info(camera_dict, camera_mode)
505
 
506
  def add_complex_camera_motion(camera_motion):
507
+ global camera_dict
508
  camera_dict['complex']=camera_motion
509
  return display_camera_info(camera_dict)
510
 
 
537
  gr.update(visible=vis_combine3_des)
538
 
539
  def input_raw_camera_pose(combine_type, camera_mode):
540
+ global camera_dict
541
  camera_dict['mode'] = combine_type
542
 
543
  vis_U = False
 
564
  gr.update(visible=vis_combine3_des)
565
 
566
  def change_camera_speed(camera_speed):
567
+ global camera_dict
568
  camera_dict['speed'] = camera_speed
569
  return display_camera_info(camera_dict)
570
 
571
  def add_traj_point(evt: gr.SelectData, ):
572
+ global traj_list
573
  traj_list.append(evt.index)
574
  traj_str = [f"{traj}" for traj in traj_list]
575
  return ", ".join(traj_str)
576
 
577
  def add_provided_traj(traj_name):
578
+ global traj_list
579
+ # import pdb; pdb.set_trace()
580
  traj_list = get_provided_traj(traj_name)
581
  traj_str = [f"{traj}" for traj in traj_list]
582
  return ", ".join(traj_str)
583
 
584
 
585
  def fn_traj_droplast():
586
+ global traj_list
 
587
  if traj_list:
588
  traj_list.pop()
589
 
 
593
  else:
594
  return "Click to specify trajectory"
595
 
596
+ def fn_vis_camera(info_mode, camera_args=None):
597
+ global camera_dict
598
+ RT = process_camera(camera_dict, camera_args) # [t, 3, 4]
599
+
600
+ rescale_T = 1.0
601
+ rescale_T = max(rescale_T, np.max(np.abs(RT[:,:,-1])) / 1.9)
602
+
603
+ fig = vis_camera(create_relative(RT), rescale_T=rescale_T)
604
+
605
+ if info_mode == MODE[0]:
606
+ vis_step3_prompt_generate = True
607
+ vis_prompt = True
608
+ vis_num_samples = True
609
+ vis_seed = True
610
+ vis_start = True
611
+ vis_gen_video = True
612
+
613
+ vis_object_mode = False
614
+ vis_object_info = False
615
+
616
+ else:
617
+ vis_step3_prompt_generate = False
618
+ vis_prompt = False
619
+ vis_num_samples = False
620
+ vis_seed = False
621
+ vis_start = False
622
+ vis_gen_video = False
623
+
624
+ vis_object_mode = True
625
+ vis_object_info = True
626
+
627
+ return fig, \
628
+ gr.update(visible=vis_object_mode), \
629
+ gr.update(visible=vis_object_info), \
630
+ gr.update(visible=vis_step3_prompt_generate), \
631
+ gr.update(visible=vis_prompt), \
632
+ gr.update(visible=vis_num_samples), \
633
+ gr.update(visible=vis_seed), \
634
+ gr.update(visible=vis_start), \
635
+ gr.update(visible=vis_gen_video, value=None)
636
+
637
+ def fn_vis_traj():
638
+ # import pdb; pdb.set_trace()
639
+ # global traj_list
640
+ # xy_range = 1024
641
+ # print(traj_list)
642
+ global traj_list
643
+ print(traj_list)
644
+ points = process_points(traj_list)
645
+ imgs = []
646
+ for idx in range(16):
647
+ bg_img = np.ones((1024, 1024, 3), dtype=np.uint8) * 255
648
+ for i in range(15):
649
+ p = points[i]
650
+ p1 = points[i+1]
651
+ cv2.line(bg_img, p, p1, (255, 0, 0), 2)
652
+
653
+ if i == idx:
654
+ cv2.circle(bg_img, p, 2, (0, 255, 0), 20)
655
+
656
+ if idx==(15):
657
+ cv2.circle(bg_img, points[-1], 2, (0, 255, 0), 20)
658
+
659
+ imgs.append(bg_img.astype(np.uint8))
660
+
661
+ path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
662
+ writer = imageio.get_writer(path, format='mp4', mode='I', fps=10)
663
+ for img in imgs:
664
+ writer.append_data(img)
665
+ writer.close()
666
+
667
+ vis_step3_prompt_generate = True
668
+ vis_prompt = True
669
+ vis_num_samples = True
670
+ vis_seed = True
671
+ vis_start = True
672
+ vis_gen_video = True
673
+ return path, gr.update(visible=vis_step3_prompt_generate), \
674
+ gr.update(visible=vis_prompt), \
675
+ gr.update(visible=vis_num_samples), \
676
+ gr.update(visible=vis_seed), \
677
+ gr.update(visible=vis_start), \
678
+ gr.update(visible=vis_gen_video, value=None)
679
+