#!/usr/bin/env python3
# tools/run_if_stages.py
# Provenance (copied from the Hugging Face file viewer): author patrickvonplaten,
# commit e58dd86 ("up"), 2.07 kB.
from diffusers import DiffusionPipeline
import torch
import time
import os
from pathlib import Path
from huggingface_hub import HfApi
# Not used below; kept so the module-level name `api` still exists.
api = HfApi()
# Measures the whole run, including model loading; reported at the end.
start_time = time.time()

model_prefix = "diffusers"

# All stage outputs are written here. Create the directory up front so that
# PIL's save() does not fail with FileNotFoundError on a fresh machine.
output_dir = Path.home() / "images"
output_dir.mkdir(parents=True, exist_ok=True)


def _pt_to_pil(images):
    """Convert a batch of IF stage outputs (``output_type="pt"``) to PIL images.

    IF denoises directly in pixel space (there is no VAE), so there are no
    latents to decode; the tensor only needs rescaling to [0, 1] and a
    channels-last conversion.

    Assumes ``images`` is an (N, C, H, W) tensor with values in [-1, 1] --
    TODO confirm against the installed diffusers version.

    Returns:
        A list of ``PIL.Image.Image``, one per batch entry.
    """
    images = (images / 2 + 0.5).clamp(0, 1)
    # .float() first: the pipelines run in fp16, and the numpy conversion
    # should happen on a full-precision CPU tensor.
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return DiffusionPipeline.numpy_to_pil(images)


# Stage I: base text-to-image model. It also owns the text encoder used to
# compute the prompt embeddings shared by all stages.
pipe = DiffusionPipeline.from_pretrained(
    f"{model_prefix}/IF-I-IF-v1.0",
    torch_dtype=torch.float16,
    safety_checker=None,
    variant="fp16",
    use_safetensors=True,
)
pipe.enable_model_cpu_offload()

# Stages II and III are loaded with text_encoder=None because they are fed
# the embeddings computed by stage I.
super_res_1_pipe = DiffusionPipeline.from_pretrained(
    f"{model_prefix}/IF-II-L-v1.0",
    text_encoder=None,
    safety_checker=None,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
super_res_1_pipe.enable_model_cpu_offload()

super_res_2_pipe = DiffusionPipeline.from_pretrained(
    f"{model_prefix}/IF-III-L-v1.0",
    text_encoder=None,
    safety_checker=None,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
super_res_2_pipe.enable_model_cpu_offload()

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
# One seeded generator, passed to every stage so the run is reproducible.
generator = torch.Generator("cuda").manual_seed(0)
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

# Stage I. output_type="pt" keeps a tensor so it can be fed to stage II.
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
_pt_to_pil(image)[0].save(output_dir / "if_stage_I_0.png")

# Stage II: first super-resolution pass, again returning a tensor for stage III.
image = super_res_1_pipe(
    image=image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
    noise_level=250,
    num_inference_steps=50,
).images
_pt_to_pil(image)[0].save(output_dir / "if_stage_II_0.png")

# Stage III: final pass; default output type, whose first element is saved directly.
image = super_res_2_pipe(
    image=image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    noise_level=0,
    num_inference_steps=40,
    generator=generator,
).images[0]
image.save(output_dir / "if_stage_III_0.png")

print(f"Total time: {time.time() - start_time:.2f} s")