Lawrence-cj committed on
Commit
accee48
1 Parent(s): e0e8c81
Files changed (1) hide show
  1. app.py +12 -43
app.py CHANGED
@@ -103,28 +103,9 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
103
  return p.replace("{prompt}", positive), n + negative
104
 
105
 
106
- def get_args():
107
- parser = argparse.ArgumentParser()
108
- parser.add_argument('--is_lora', action='store_true', help='enable lora ckpt loading')
109
- parser.add_argument('--repo_id', default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", type=str)
110
- parser.add_argument('--lora_repo_id', default=None, type=str)
111
- parser.add_argument('--model_path', default=None, type=str)
112
- parser.add_argument(
113
- '--pipeline_load_from', default="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", type=str,
114
- help="Download for loading text_encoder, tokenizer and vae "
115
- "from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS")
116
- parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5')
117
- return parser.parse_args()
118
-
119
-
120
- args = get_args()
121
-
122
  if torch.cuda.is_available():
123
  weight_dtype = torch.float16
124
- T5_token_max_length = args.T5_token_max_length
125
- model_path = args.model_path
126
- if 'Sigma' in args.model_path:
127
- T5_token_max_length = 300
128
 
129
  # tmp patches for diffusers PixArtSigmaPipeline Implementation
130
  print(
@@ -132,29 +113,17 @@ if torch.cuda.is_available():
132
  "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
133
  setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
134
 
135
- if not args.is_lora:
136
- transformer = Transformer2DModel.from_pretrained(
137
- model_path,
138
- subfolder='transformer',
139
- torch_dtype=weight_dtype,
140
- )
141
- pipe = PixArtSigmaPipeline.from_pretrained(
142
- args.pipeline_load_from,
143
- transformer=transformer,
144
- torch_dtype=weight_dtype,
145
- use_safetensors=True,
146
- )
147
- else:
148
- assert args.lora_repo_id is not None
149
- transformer = Transformer2DModel.from_pretrained(args.repo_id, subfolder="transformer", torch_dtype=torch.float16)
150
- transformer = PeftModel.from_pretrained(transformer, args.lora_repo_id)
151
- pipe = PixArtSigmaPipeline.from_pretrained(
152
- args.repo_id,
153
- transformer=transformer,
154
- torch_dtype=torch.float16,
155
- use_safetensors=True,
156
- )
157
- del transformer
158
 
159
  if os.getenv('CONSISTENCY_DECODER', False):
160
  print("Using DALL-E 3 Consistency Decoder")
 
103
  return p.replace("{prompt}", positive), n + negative
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if torch.cuda.is_available():
107
  weight_dtype = torch.float16
108
+ T5_token_max_length = 300
 
 
 
109
 
110
  # tmp patches for diffusers PixArtSigmaPipeline Implementation
111
  print(
 
113
  "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
114
  setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
115
 
116
+ transformer = Transformer2DModel.from_pretrained(
117
+ "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
118
+ subfolder='transformer',
119
+ torch_dtype=weight_dtype,
120
+ )
121
+ pipe = PixArtSigmaPipeline.from_pretrained(
122
+ "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
123
+ transformer=transformer,
124
+ torch_dtype=weight_dtype,
125
+ use_safetensors=True,
126
+ )
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  if os.getenv('CONSISTENCY_DECODER', False):
129
  print("Using DALL-E 3 Consistency Decoder")