fffiloni commited on
Commit
ae33a9e
1 Parent(s): 3649171

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -13,6 +13,8 @@
13
  # limitations under the License.
14
 
15
  import gradio as gr
 
 
16
  import argparse
17
  import inspect
18
  import os
@@ -1616,22 +1618,20 @@ if __name__ == "__main__":
1616
  parser.add_argument('--experiment_name', default="AccDiffusion")
1617
 
1618
  args = parser.parse_args()
1619
-
 
 
 
1620
  # GRADIO MODE
1621
 
1622
- def infer(prompt, progress=gr.Progress(track_tqdm=True)):
 
1623
  set_seed(args.seed)
1624
  width,height = list(map(int, args.resolution.split(',')))
1625
- pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
1626
- generator = torch.Generator(device='cuda')
1627
- generator = generator.manual_seed(args.seed)
1628
  cross_attention_kwargs = {"edit_type": "visualize",
1629
  "n_self_replace": 0.4,
1630
  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
1631
  }
1632
-
1633
-
1634
-
1635
  seed = args.seed
1636
  generator = generator.manual_seed(seed)
1637
 
@@ -1644,8 +1644,8 @@ if __name__ == "__main__":
1644
  view_batch_size=args.view_batch_size,
1645
  stride=args.stride,
1646
  cross_attention_kwargs=cross_attention_kwargs,
1647
- num_inference_steps=args.num_inference_steps,
1648
- guidance_scale = 7.5,
1649
  multi_guidance_scale = args.multi_guidance_scale,
1650
  cosine_scale_1=args.cosine_scale_1,
1651
  cosine_scale_2=args.cosine_scale_2,
@@ -1680,7 +1680,7 @@ if __name__ == "__main__":
1680
  <img src='https://img.shields.io/badge/Project-Page-blue'>
1681
  </a>
1682
  <a href='https://github.com/lzhxmu/AccDiffusion'>
1683
- <img src='https://img.shields.io/badge/Code-blue'>
1684
  </a>
1685
  <a href='https://arxiv.org/abs/2407.10738v1'>
1686
  <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
@@ -1688,6 +1688,9 @@ if __name__ == "__main__":
1688
  </div>
1689
  """)
1690
  prompt = gr.Textbox(label="Prompt")
 
 
 
1691
  submit_btn = gr.Button("Submit")
1692
  output_images = gr.Image(format="png")
1693
  gr.Examples(
@@ -1700,7 +1703,7 @@ if __name__ == "__main__":
1700
  )
1701
  submit_btn.click(
1702
  fn = infer,
1703
- inputs = [prompt],
1704
  outputs = [output_images],
1705
  show_api=False
1706
  )
 
13
  # limitations under the License.
14
 
15
  import gradio as gr
16
+ import spaces
17
+
18
  import argparse
19
  import inspect
20
  import os
 
1618
  parser.add_argument('--experiment_name', default="AccDiffusion")
1619
 
1620
  args = parser.parse_args()
1621
+
1622
+ pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
1623
+ generator = torch.Generator(device='cuda')
1624
+
1625
  # GRADIO MODE
1626
 
1627
+ @spaces.GPU()
1628
+ def infer(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
1629
  set_seed(args.seed)
1630
  width,height = list(map(int, args.resolution.split(',')))
 
 
 
1631
  cross_attention_kwargs = {"edit_type": "visualize",
1632
  "n_self_replace": 0.4,
1633
  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
1634
  }
 
 
 
1635
  seed = args.seed
1636
  generator = generator.manual_seed(seed)
1637
 
 
1644
  view_batch_size=args.view_batch_size,
1645
  stride=args.stride,
1646
  cross_attention_kwargs=cross_attention_kwargs,
1647
+ num_inference_steps=num_inference_steps,
1648
+ guidance_scale = guidance_scale,
1649
  multi_guidance_scale = args.multi_guidance_scale,
1650
  cosine_scale_1=args.cosine_scale_1,
1651
  cosine_scale_2=args.cosine_scale_2,
 
1680
  <img src='https://img.shields.io/badge/Project-Page-blue'>
1681
  </a>
1682
  <a href='https://github.com/lzhxmu/AccDiffusion'>
1683
+ <img src='https://img.shields.io/badge/Code-github-blue'>
1684
  </a>
1685
  <a href='https://arxiv.org/abs/2407.10738v1'>
1686
  <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
 
1688
  </div>
1689
  """)
1690
  prompt = gr.Textbox(label="Prompt")
1691
+ with gr.Accordion("Advanced settings", open=False):
1692
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=2, maximum=50, step=1, value=50)
1693
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=510, step=0.1, value=7.5)
1694
  submit_btn = gr.Button("Submit")
1695
  output_images = gr.Image(format="png")
1696
  gr.Examples(
 
1703
  )
1704
  submit_btn.click(
1705
  fn = infer,
1706
+ inputs = [prompt, num_inference_steps, guidance_scale],
1707
  outputs = [output_images],
1708
  show_api=False
1709
  )