multimodalart HF staff commited on
Commit
d55e5c1
1 Parent(s): b8d159a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -28,11 +28,13 @@ args = tyro.cli(ArgumentConfig)
28
  # specify configs for inference
29
  inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
30
  crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
 
31
  gradio_pipeline = GradioPipeline(
32
  inference_cfg=inference_cfg,
33
  crop_cfg=crop_cfg,
34
  args=args
35
  )
 
36
  # assets
37
  title_md = "assets/gradio_title.md"
38
  example_portrait_dir = "assets/examples/source"
@@ -100,8 +102,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
  flag_do_crop_input,
101
  flag_remap_input
102
  ],
 
103
  examples_per_page=5,
104
- cache_examples="lazy"
 
105
  )
106
  gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
107
  with gr.Row():
@@ -137,7 +141,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
137
  show_progress=True
138
  )
139
  process_button_animation.click(
140
- fn=gradio_pipeline.execute_video,
141
  inputs=[
142
  image_input,
143
  video_input,
 
28
  # specify configs for inference
29
  inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
30
  crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
31
+
32
  gradio_pipeline = GradioPipeline(
33
  inference_cfg=inference_cfg,
34
  crop_cfg=crop_cfg,
35
  args=args
36
  )
37
+ gradio_pipeline.execute_video
38
  # assets
39
  title_md = "assets/gradio_title.md"
40
  example_portrait_dir = "assets/examples/source"
 
102
  flag_do_crop_input,
103
  flag_remap_input
104
  ],
105
+ outputs=[output_image, output_image_paste_back],
106
  examples_per_page=5,
107
+ cache_examples="lazy",
108
+ fn=lambda *args: spaces.GPU()(gradio_pipeline.execute_video)(*args),
109
  )
110
  gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
111
  with gr.Row():
 
141
  show_progress=True
142
  )
143
  process_button_animation.click(
144
+ fn=lambda *args: spaces.GPU()(gradio_pipeline.execute_video)(*args),
145
  inputs=[
146
  image_input,
147
  video_input,