# Prediction interface for Cog ⚙️ # https://github.com/replicate/cog/blob/main/docs/python.md import os import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../gradio_demo")) import cv2 import time import torch import mimetypes import subprocess import numpy as np from typing import List from cog import BasePredictor, Input, Path import PIL from PIL import Image import diffusers from diffusers import LCMScheduler from diffusers.utils import load_image from diffusers.models import ControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from model_util import get_torch_device from insightface.app import FaceAnalysis from transformers import CLIPImageProcessor from controlnet_util import openpose, get_depth_map, get_canny_image from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) from pipeline_stable_diffusion_xl_instantid_full import ( StableDiffusionXLInstantIDPipeline, draw_kps, ) mimetypes.add_type("image/webp", ".webp") # GPU global variables DEVICE = get_torch_device() DTYPE = torch.float16 if str(DEVICE).__contains__("cuda") else torch.float32 # for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0` CHECKPOINTS_CACHE = "./checkpoints" CHECKPOINTS_URL = "https://weights.replicate.delivery/default/InstantID/checkpoints.tar" # for `models/antelopev2` MODELS_CACHE = "./models" MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar" # for the safety checker SAFETY_CACHE = "./safety-cache" FEATURE_EXTRACTOR = "./feature-extractor" SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar" SDXL_NAME_TO_PATHLIKE = { # These are all huggingface models that we host via gcp + pget "stable-diffusion-xl-base-1.0": { "slug": "stabilityai/stable-diffusion-xl-base-1.0", "url": "https://weights.replicate.delivery/default/InstantID/models--stabilityai--stable-diffusion-xl-base-1.0.tar", "path": "checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0", }, "afrodite-xl-v2": { "slug": "stablediffusionapi/afrodite-xl-v2", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--afrodite-xl-v2.tar", "path": "checkpoints/models--stablediffusionapi--afrodite-xl-v2", }, "albedobase-xl-20": { "slug": "stablediffusionapi/albedobase-xl-20", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--albedobase-xl-20.tar", "path": "checkpoints/models--stablediffusionapi--albedobase-xl-20", }, "albedobase-xl-v13": { "slug": "stablediffusionapi/albedobase-xl-v13", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--albedobase-xl-v13.tar", "path": "checkpoints/models--stablediffusionapi--albedobase-xl-v13", }, "animagine-xl-30": { "slug": "stablediffusionapi/animagine-xl-30", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--animagine-xl-30.tar", "path": "checkpoints/models--stablediffusionapi--animagine-xl-30", }, "anime-art-diffusion-xl": { "slug": "stablediffusionapi/anime-art-diffusion-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--anime-art-diffusion-xl.tar", "path": "checkpoints/models--stablediffusionapi--anime-art-diffusion-xl", }, "anime-illust-diffusion-xl": { "slug": "stablediffusionapi/anime-illust-diffusion-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--anime-illust-diffusion-xl.tar", "path": "checkpoints/models--stablediffusionapi--anime-illust-diffusion-xl", }, "dreamshaper-xl": { "slug": "stablediffusionapi/dreamshaper-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--dreamshaper-xl.tar", "path": "checkpoints/models--stablediffusionapi--dreamshaper-xl", }, "dynavision-xl-v0610": { "slug": "stablediffusionapi/dynavision-xl-v0610", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--dynavision-xl-v0610.tar", "path": "checkpoints/models--stablediffusionapi--dynavision-xl-v0610", }, "guofeng4-xl": { "slug": "stablediffusionapi/guofeng4-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--guofeng4-xl.tar", "path": "checkpoints/models--stablediffusionapi--guofeng4-xl", }, "juggernaut-xl-v8": { "slug": "stablediffusionapi/juggernaut-xl-v8", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--juggernaut-xl-v8.tar", "path": "checkpoints/models--stablediffusionapi--juggernaut-xl-v8", }, "nightvision-xl-0791": { "slug": "stablediffusionapi/nightvision-xl-0791", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--nightvision-xl-0791.tar", "path": "checkpoints/models--stablediffusionapi--nightvision-xl-0791", }, "omnigen-xl": { "slug": "stablediffusionapi/omnigen-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--omnigen-xl.tar", "path": "checkpoints/models--stablediffusionapi--omnigen-xl", }, "pony-diffusion-v6-xl": { "slug": "stablediffusionapi/pony-diffusion-v6-xl", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--pony-diffusion-v6-xl.tar", "path": "checkpoints/models--stablediffusionapi--pony-diffusion-v6-xl", }, "protovision-xl-high-fidel": { "slug": "stablediffusionapi/protovision-xl-high-fidel", "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--protovision-xl-high-fidel.tar", "path": "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel", }, "RealVisXL_V3.0_Turbo": { "slug": "SG161222/RealVisXL_V3.0_Turbo", "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V3.0_Turbo.tar", "path": "checkpoints/models--SG161222--RealVisXL_V3.0_Turbo", }, "RealVisXL_V4.0_Lightning": { "slug": "SG161222/RealVisXL_V4.0_Lightning", "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V4.0_Lightning.tar", "path": "checkpoints/models--SG161222--RealVisXL_V4.0_Lightning", }, } def convert_from_cv2_to_image(img: np.ndarray) -> Image: return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) def convert_from_image_to_cv2(img: Image) -> np.ndarray: return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) def resize_img( input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64, ): w, h = input_image.size if size is not None: w_resize_new, h_resize_new = size else: ratio = min_side / min(h, w) w, h = round(ratio * w), round(ratio * h) ratio = max_side / max(h, w) input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode) w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number input_image = input_image.resize([w_resize_new, h_resize_new], mode) if pad_to_max_side: res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 offset_x = (max_side - w_resize_new) // 2 offset_y = (max_side - h_resize_new) // 2 res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = ( np.array(input_image) ) input_image = Image.fromarray(res) return input_image def download_weights(url, dest): start = time.time() print("[!] Initiating download from URL: ", url) print("[~] Destination path: ", dest) command = ["pget", "-vf", url, dest] if ".tar" in url: command.append("-x") try: subprocess.check_call(command, close_fds=False) except subprocess.CalledProcessError as e: print( f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}." ) raise print("[+] Download completed in: ", time.time() - start, "seconds") class Predictor(BasePredictor): def setup(self) -> None: """Load the model into memory to make running multiple predictions efficient""" if not os.path.exists(CHECKPOINTS_CACHE): download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE) if not os.path.exists(MODELS_CACHE): download_weights(MODELS_URL, MODELS_CACHE) self.face_detection_input_width, self.face_detection_input_height = 640, 640 self.app = FaceAnalysis( name="antelopev2", root="./", providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) self.app.prepare(ctx_id=0, det_size=(self.face_detection_input_width, self.face_detection_input_height)) # Path to InstantID models self.face_adapter = f"./checkpoints/ip-adapter.bin" controlnet_path = f"./checkpoints/ControlNetModel" # Load pipeline face ControlNetModel self.controlnet_identitynet = ControlNetModel.from_pretrained( controlnet_path, torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, local_files_only=True, ) self.setup_extra_controlnets() self.load_weights("stable-diffusion-xl-base-1.0") self.setup_safety_checker() def setup_safety_checker(self): print(f"[~] Seting up safety checker") if not os.path.exists(SAFETY_CACHE): download_weights(SAFETY_URL, SAFETY_CACHE) self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( SAFETY_CACHE, torch_dtype=DTYPE, local_files_only=True, ) self.safety_checker.to(DEVICE) self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) def run_safety_checker(self, image): safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( DEVICE ) np_image = np.array(image) image, has_nsfw_concept = self.safety_checker( images=[np_image], clip_input=safety_checker_input.pixel_values.to(DTYPE), ) return image, has_nsfw_concept def load_weights(self, sdxl_weights): self.base_weights = sdxl_weights weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights] download_url = weights_info["url"] path_to_weights_dir = weights_info["path"] if not os.path.exists(path_to_weights_dir): download_weights(download_url, path_to_weights_dir) is_hugging_face_model = "slug" in weights_info.keys() path_to_weights_file = os.path.join( path_to_weights_dir, weights_info.get("file", ""), ) print(f"[~] Loading new SDXL weights: {path_to_weights_file}") if is_hugging_face_model: self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( weights_info["slug"], controlnet=[self.controlnet_identitynet], torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, local_files_only=True, safety_checker=None, feature_extractor=None, ) self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config( self.pipe.scheduler.config ) else: # e.g. .safetensors, NOTE: This functionality is not being used right now self.pipe.from_single_file( path_to_weights_file, controlnet=self.controlnet_identitynet, torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, ) self.pipe.load_ip_adapter_instantid(self.face_adapter) self.setup_lcm_lora() self.pipe.cuda() def setup_lcm_lora(self): print(f"[~] Seting up LCM (just in case)") lcm_lora_key = "models--latent-consistency--lcm-lora-sdxl" lcm_lora_path = f"checkpoints/{lcm_lora_key}" if not os.path.exists(lcm_lora_path): download_weights( f"https://weights.replicate.delivery/default/InstantID/{lcm_lora_key}.tar", lcm_lora_path, ) self.pipe.load_lora_weights( "latent-consistency/lcm-lora-sdxl", cache_dir=CHECKPOINTS_CACHE, local_files_only=True, weight_name="pytorch_lora_weights.safetensors", ) self.pipe.disable_lora() def setup_extra_controlnets(self): print(f"[~] Seting up pose, canny, depth ControlNets") controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0" controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0" controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small" for controlnet_key in [ "models--diffusers--controlnet-canny-sdxl-1.0", "models--diffusers--controlnet-depth-sdxl-1.0-small", "models--thibaud--controlnet-openpose-sdxl-1.0", ]: controlnet_path = f"checkpoints/{controlnet_key}" if not os.path.exists(controlnet_path): download_weights( f"https://weights.replicate.delivery/default/InstantID/{controlnet_key}.tar", controlnet_path, ) controlnet_pose = ControlNetModel.from_pretrained( controlnet_pose_model, torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, local_files_only=True, ).to(DEVICE) controlnet_canny = ControlNetModel.from_pretrained( controlnet_canny_model, torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, local_files_only=True, ).to(DEVICE) controlnet_depth = ControlNetModel.from_pretrained( controlnet_depth_model, torch_dtype=DTYPE, cache_dir=CHECKPOINTS_CACHE, local_files_only=True, ).to(DEVICE) self.controlnet_map = { "pose": controlnet_pose, "canny": controlnet_canny, "depth": controlnet_depth, } self.controlnet_map_fn = { "pose": openpose, "canny": get_canny_image, "depth": get_depth_map, } def generate_image( self, face_image_path, pose_image_path, prompt, negative_prompt, num_steps, identitynet_strength_ratio, adapter_strength_ratio, pose_strength, canny_strength, depth_strength, controlnet_selection, guidance_scale, seed, scheduler, enable_LCM, enhance_face_region, num_images_per_prompt, ): if enable_LCM: self.pipe.enable_lora() self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) else: self.pipe.disable_lora() scheduler_class_name = scheduler.split("-")[0] add_kwargs = {} if len(scheduler.split("-")) > 1: add_kwargs["use_karras_sigmas"] = True if len(scheduler.split("-")) > 2: add_kwargs["algorithm_type"] = "sde-dpmsolver++" scheduler = getattr(diffusers, scheduler_class_name) self.pipe.scheduler = scheduler.from_config( self.pipe.scheduler.config, **add_kwargs, ) if face_image_path is None: raise Exception( f"Cannot find any input face `image`! Please upload the face `image`" ) face_image = load_image(face_image_path) face_image = resize_img(face_image) face_image_cv2 = convert_from_image_to_cv2(face_image) height, width, _ = face_image_cv2.shape # Extract face features face_info = self.app.get(face_image_cv2) if len(face_info) == 0: raise Exception( "Face detector could not find a face in the `image`. Please use a different `image` as input." ) face_info = sorted( face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1], )[ -1 ] # only use the maximum face face_emb = face_info["embedding"] face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"]) img_controlnet = face_image if pose_image_path is not None: pose_image = load_image(pose_image_path) pose_image = resize_img(pose_image, max_side=1024) img_controlnet = pose_image pose_image_cv2 = convert_from_image_to_cv2(pose_image) face_info = self.app.get(pose_image_cv2) if len(face_info) == 0: raise Exception( "Face detector could not find a face in the `pose_image`. Please use a different `pose_image` as input." ) face_info = face_info[-1] face_kps = draw_kps(pose_image, face_info["kps"]) width, height = face_kps.size if enhance_face_region: control_mask = np.zeros([height, width, 3]) x1, y1, x2, y2 = face_info["bbox"] x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) control_mask[y1:y2, x1:x2] = 255 control_mask = Image.fromarray(control_mask.astype(np.uint8)) else: control_mask = None if len(controlnet_selection) > 0: controlnet_scales = { "pose": pose_strength, "canny": canny_strength, "depth": depth_strength, } self.pipe.controlnet = MultiControlNetModel( [self.controlnet_identitynet] + [self.controlnet_map[s] for s in controlnet_selection] ) control_scales = [float(identitynet_strength_ratio)] + [ controlnet_scales[s] for s in controlnet_selection ] control_images = [face_kps] + [ self.controlnet_map_fn[s](img_controlnet).resize((width, height)) for s in controlnet_selection ] else: self.pipe.controlnet = self.controlnet_identitynet control_scales = float(identitynet_strength_ratio) control_images = face_kps generator = torch.Generator(device=DEVICE).manual_seed(seed) print("Start inference...") print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") self.pipe.set_ip_adapter_scale(adapter_strength_ratio) images = self.pipe( prompt=prompt, negative_prompt=negative_prompt, image_embeds=face_emb, image=control_images, control_mask=control_mask, controlnet_conditioning_scale=control_scales, num_inference_steps=num_steps, guidance_scale=guidance_scale, height=height, width=width, generator=generator, num_images_per_prompt=num_images_per_prompt, ).images return images def predict( self, image: Path = Input( description="Input face image", ), pose_image: Path = Input( description="(Optional) reference pose image", default=None, ), prompt: str = Input( description="Input prompt", default="a person", ), negative_prompt: str = Input( description="Input Negative Prompt", default="", ), sdxl_weights: str = Input( description="Pick which base weights you want to use", default="stable-diffusion-xl-base-1.0", choices=[ "stable-diffusion-xl-base-1.0", "juggernaut-xl-v8", "afrodite-xl-v2", "albedobase-xl-20", "albedobase-xl-v13", "animagine-xl-30", "anime-art-diffusion-xl", "anime-illust-diffusion-xl", "dreamshaper-xl", "dynavision-xl-v0610", "guofeng4-xl", "nightvision-xl-0791", "omnigen-xl", "pony-diffusion-v6-xl", "protovision-xl-high-fidel", "RealVisXL_V3.0_Turbo", "RealVisXL_V4.0_Lightning", ], ), face_detection_input_width: int = Input( description="Width of the input image for face detection", default=640, ge=640, le=4096, ), face_detection_input_height: int = Input( description="Height of the input image for face detection", default=640, ge=640, le=4096, ), scheduler: str = Input( description="Scheduler", choices=[ "DEISMultistepScheduler", "HeunDiscreteScheduler", "EulerDiscreteScheduler", "DPMSolverMultistepScheduler", "DPMSolverMultistepScheduler-Karras", "DPMSolverMultistepScheduler-Karras-SDE", ], default="EulerDiscreteScheduler", ), num_inference_steps: int = Input( description="Number of denoising steps", default=30, ge=1, le=500, ), guidance_scale: float = Input( description="Scale for classifier-free guidance", default=7.5, ge=1, le=50, ), ip_adapter_scale: float = Input( description="Scale for image adapter strength (for detail)", # adapter_strength_ratio default=0.8, ge=0, le=1.5, ), controlnet_conditioning_scale: float = Input( description="Scale for IdentityNet strength (for fidelity)", # identitynet_strength_ratio default=0.8, ge=0, le=1.5, ), enable_pose_controlnet: bool = Input( description="Enable Openpose ControlNet, overrides strength if set to false", default=True, ), pose_strength: float = Input( description="Openpose ControlNet strength, effective only if `enable_pose_controlnet` is true", default=0.4, ge=0, le=1, ), enable_canny_controlnet: bool = Input( description="Enable Canny ControlNet, overrides strength if set to false", default=False, ), canny_strength: float = Input( description="Canny ControlNet strength, effective only if `enable_canny_controlnet` is true", default=0.3, ge=0, le=1, ), enable_depth_controlnet: bool = Input( description="Enable Depth ControlNet, overrides strength if set to false", default=False, ), depth_strength: float = Input( description="Depth ControlNet strength, effective only if `enable_depth_controlnet` is true", default=0.5, ge=0, le=1, ), enable_lcm: bool = Input( description="Enable Fast Inference with LCM (Latent Consistency Models) - speeds up inference steps, trade-off is the quality of the generated image. Performs better with close-up portrait face images", default=False, ), lcm_num_inference_steps: int = Input( description="Only used when `enable_lcm` is set to True, Number of denoising steps when using LCM", default=5, ge=1, le=10, ), lcm_guidance_scale: float = Input( description="Only used when `enable_lcm` is set to True, Scale for classifier-free guidance when using LCM", default=1.5, ge=1, le=20, ), enhance_nonface_region: bool = Input( description="Enhance non-face region", default=True ), output_format: str = Input( description="Format of the output images", choices=["webp", "jpg", "png"], default="webp", ), output_quality: int = Input( description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.", default=80, ge=0, le=100, ), seed: int = Input( description="Random seed. Leave blank to randomize the seed", default=None, ), num_outputs: int = Input( description="Number of images to output", default=1, ge=1, le=8, ), disable_safety_checker: bool = Input( description="Disable safety checker for generated images", default=False, ), ) -> List[Path]: """Run a single prediction on the model""" # If no seed is provided, generate a random seed if seed is None: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") # Load the weights if they are different from the base weights if sdxl_weights != self.base_weights: self.load_weights(sdxl_weights) # Resize the output if the provided dimensions are different from the current ones if self.face_detection_input_width != face_detection_input_width or self.face_detection_input_height != face_detection_input_height: print(f"[!] Resizing output to {face_detection_input_width}x{face_detection_input_height}") self.face_detection_input_width = face_detection_input_width self.face_detection_input_height = face_detection_input_height self.app.prepare(ctx_id=0, det_size=(self.face_detection_input_width, self.face_detection_input_height)) # Set up ControlNet selection and their respective strength values (if any) controlnet_selection = [] if pose_strength > 0 and enable_pose_controlnet: controlnet_selection.append("pose") if canny_strength > 0 and enable_canny_controlnet: controlnet_selection.append("canny") if depth_strength > 0 and enable_depth_controlnet: controlnet_selection.append("depth") # Switch to LCM inference steps and guidance scale if LCM is enabled if enable_lcm: num_inference_steps = lcm_num_inference_steps guidance_scale = lcm_guidance_scale # Generate images = self.generate_image( face_image_path=str(image), pose_image_path=str(pose_image) if pose_image else None, prompt=prompt, negative_prompt=negative_prompt, num_steps=num_inference_steps, identitynet_strength_ratio=controlnet_conditioning_scale, adapter_strength_ratio=ip_adapter_scale, pose_strength=pose_strength, canny_strength=canny_strength, depth_strength=depth_strength, controlnet_selection=controlnet_selection, scheduler=scheduler, guidance_scale=guidance_scale, seed=seed, enable_LCM=enable_lcm, enhance_face_region=enhance_nonface_region, num_images_per_prompt=num_outputs, ) # Save the generated images and check for NSFW content output_paths = [] for i, output_image in enumerate(images): if not disable_safety_checker: _, has_nsfw_content_list = self.run_safety_checker(output_image) has_nsfw_content = any(has_nsfw_content_list) print(f"NSFW content detected: {has_nsfw_content}") if has_nsfw_content: raise Exception( "NSFW content detected. Try running it again, or try a different prompt." ) extension = output_format.lower() extension = "jpeg" if extension == "jpg" else extension output_path = f"/tmp/out_{i}.{extension}" print(f"[~] Saving to {output_path}...") print(f"[~] Output format: {extension.upper()}") if output_format != "png": print(f"[~] Output quality: {output_quality}") save_params = {"format": extension.upper()} if output_format != "png": save_params["quality"] = output_quality save_params["optimize"] = True output_image.save(output_path, **save_params) output_paths.append(Path(output_path)) return output_paths