h1t commited on
Commit
1920e68
1 Parent(s): 2338068

change layout

Browse files
Files changed (1) hide show
  1. app.py +49 -22
app.py CHANGED
@@ -35,6 +35,7 @@ class GradioDemo:
35
  num_inference_steps = 4,
36
  sd_pipe_guidance_scale = 1.0,
37
  seed = 1024,
 
38
  ):
39
  pipe_kwargs = dict(
40
  prompt = prompt,
@@ -51,18 +52,29 @@ class GradioDemo:
51
  )['images'][0]
52
 
53
  generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
54
- pipe_kwargs.update(oms_flag=True, oms_prompt=oms_prompt, oms_guidance_scale=1.0)
55
- print(f'w/ oms wo/ cfg kwargs: {pipe_kwargs}')
56
- image_oms = self.pipe(
57
  **pipe_kwargs,
58
  generator=generator
59
  )['images'][0]
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  oms_guidance_flag = oms_guidance_scale != 1.0
62
  if oms_guidance_flag:
63
  generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
64
  pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
65
- print(f'w/ oms +cfg kwargs: {pipe_kwargs}')
66
  image_oms_cfg = self.pipe(
67
  **pipe_kwargs,
68
  generator=generator
@@ -70,29 +82,44 @@ class GradioDemo:
70
  else:
71
  image_oms_cfg = None
72
 
73
- return image_raw, image_oms, image_oms_cfg, gr.update(visible=oms_guidance_flag)
74
 
75
  def mainloop(self):
76
  with gr.Blocks() as demo:
77
- gr.Markdown("# One More Step Demo")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  with gr.Row():
79
- with gr.Column():
80
- prompt = gr.Textbox(label="Prompt", value="a cat against black ground, studio")
81
- oms_prompt = gr.Textbox(label="OMS Prompt", value="an orange cat")
82
- oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1)
83
- run_button = gr.Button(value="Generate images")
84
- with gr.Accordion("Advanced options", open=False):
85
- num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
86
- sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
87
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
88
- with gr.Column():
89
- output_raw = gr.Image(label="SDXL w/ LCM-LoRA w/o OMS ")
90
- output_oms = gr.Image(label="w/ OMS w/o OMS CFG")
91
- with gr.Column(visible=False) as oms_cfg_wd:
92
- output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")
93
 
94
- ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed]
95
- run_button.click(fn=self._inference, inputs=ips, outputs=[output_raw, output_oms, output_oms_cfg, oms_cfg_wd])
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  demo.queue(max_size=20)
98
  demo.launch()
 
35
  num_inference_steps = 4,
36
  sd_pipe_guidance_scale = 1.0,
37
  seed = 1024,
38
+ oms_prompt_flag=True,
39
  ):
40
  pipe_kwargs = dict(
41
  prompt = prompt,
 
52
  )['images'][0]
53
 
54
  generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
55
+ pipe_kwargs.update(oms_flag=True, oms_prompt=prompt, oms_guidance_scale=1.0)
56
+ print(f'w/ oms wo/ cfg (consistent) kwargs: {pipe_kwargs}')
57
+ image_oms_cp = self.pipe(
58
  **pipe_kwargs,
59
  generator=generator
60
  )['images'][0]
61
 
62
+ if oms_prompt_flag:
63
+ generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
64
+ pipe_kwargs.update(oms_prompt=oms_prompt)
65
+ print(f'w/ oms wo/ cfg (inconsistent) kwargs: {pipe_kwargs}')
66
+ image_oms_icp = self.pipe(
67
+ **pipe_kwargs,
68
+ generator=generator
69
+ )['images'][0]
70
+ else:
71
+ image_oms_icp = None
72
+
73
  oms_guidance_flag = oms_guidance_scale != 1.0
74
  if oms_guidance_flag:
75
  generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
76
  pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
77
+ print(f'w/ oms +cfg (inconsistent) kwargs: {pipe_kwargs}')
78
  image_oms_cfg = self.pipe(
79
  **pipe_kwargs,
80
  generator=generator
 
82
  else:
83
  image_oms_cfg = None
84
 
85
+ return image_raw, image_oms_cp, image_oms_icp, image_oms_cfg, gr.update(visible=oms_prompt_flag), gr.update(visible=oms_guidance_flag)
86
 
87
  def mainloop(self):
88
  with gr.Blocks() as demo:
89
+ gr.Markdown("# One More Step for SDXL w/ LCM-LoRA")
90
+
91
+ with gr.Group() as inputs:
92
+ prompt = gr.Textbox(label="Prompt", value="a cat against orange ground, studio")
93
+ with gr.Accordion('OMS Prompt'):
94
+ oms_prompt_checkbox = gr.Checkbox(info="Inconsistent OMS prompt allows the additional control of low freq info, default is the same as Prompt.", label="Adding OMS Prompt", value=True)
95
+ oms_prompt = gr.Textbox(label="OMS Prompt", value="a black cat", info='try "a black cat" and "a black room" for diverse control.')
96
+ with gr.Accordion('OMS Guidance'):
97
+ oms_cfg_scale_checkbox = gr.Checkbox(info="OMS Guidance will enhance the OMS prompt, specially focus on color and brightness. ", label="Adding OMS Guidance", value=True)
98
+ oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=2., step=0.1)
99
+ run_button = gr.Button(value="Generate images")
100
+ with gr.Accordion("Advanced options", open=False):
101
+ num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
102
+ sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=1, maximum=3, value=1.0, step=0.1)
103
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
104
  with gr.Row():
105
+ output_raw = gr.Image(label="SDXL w/ LCM-LoRA ")
106
+ output_oms_cp = gr.Image(label="w/ OMS (consistent prompt) w/o OMS CFG")
107
+ output_oms_icp = gr.Image(label="w/ OMS (inconsistent prompt) w/o OMS CFG")
108
+ output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")
 
 
 
 
 
 
 
 
 
 
109
 
110
+ oms_prompt_checkbox.input(
111
+ fn=lambda oms_prompt_flag, prompt, oms_prompt: (oms_prompt if oms_prompt_flag else prompt, gr.update(interactive=oms_prompt_flag)),
112
+ inputs=[oms_prompt_checkbox, prompt, oms_prompt],
113
+ outputs=[oms_prompt, oms_prompt]
114
+ )
115
+ oms_cfg_scale_checkbox.input(
116
+ fn=lambda oms_cfg_scale_flag: (1.5 if oms_cfg_scale_flag else 1.0, gr.update(interactive=oms_cfg_scale_flag)),
117
+ inputs=[oms_cfg_scale_checkbox],
118
+ outputs=[oms_guidance_scale, oms_guidance_scale]
119
+ )
120
+
121
+ ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed, oms_prompt_checkbox]
122
+ run_button.click(fn=self._inference, inputs=ips, outputs=[output_raw, output_oms_cp, output_oms_icp, output_oms_cfg, output_oms_icp, output_oms_cfg])
123
 
124
  demo.queue(max_size=20)
125
  demo.launch()