# Captured from a web file viewer: Test1 / src/models/common.py
# Last commit f1cc496 by AndranikSargsyan: "add support for diffusers checkpoint loading"
# (viewer labels: raw | history | blame | No virus | 5.55 kB)
import importlib
import requests
from collections import OrderedDict
from pathlib import Path
from os.path import dirname
import torch
import safetensors
import safetensors.torch
from omegaconf import OmegaConf
from tqdm import tqdm
from src.smplfusion import DDIM, share, scheduler
from src.utils.convert_diffusers_to_sd import (
convert_vae_state_dict,
convert_unet_state_dict,
convert_text_enc_state_dict,
convert_text_enc_state_dict_v20
)
# This module lives at <PROJECT_DIR>/src/models/common.py: climb out of
# models/ and then out of src/ to reach the repository root.
_SRC_DIR = dirname(dirname(__file__))
PROJECT_DIR = dirname(_SRC_DIR)
# YAML configs read by OmegaConf / load_obj (ddpm/, unet/, encoders/, vae.yaml).
CONFIG_FOLDER = f'{PROJECT_DIR}/config'
# Destination directory for checkpoints fetched by download_file.
MODEL_FOLDER = f'{PROJECT_DIR}/checkpoints'
def download_file(url, save_path, chunk_size=1024, timeout=None):
    """Download ``url`` to ``save_path`` unless that file already exists.

    The response body is streamed in ``chunk_size`` pieces, with a tqdm
    progress bar, into a sibling ``<name>.part`` file. That file is renamed to
    ``save_path`` only after the whole body has been written. A failed or
    interrupted download therefore never leaves a truncated file at
    ``save_path`` that the existence check would later accept as complete.

    Args:
        url: URL to fetch.
        save_path: Destination path (str or PathLike). Missing parent
            directories are created.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.
        timeout: Forwarded to ``requests.get``. ``None`` (the default) means
            no timeout.

    Raises:
        Exception: ``"Download failed: ..."``, chained to the underlying
            error. This includes HTTP error statuses, so an error page is
            never saved as the requested file.
    """
    tmp_path = None
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        tmp_path = save_path.with_name(save_path.name + '.part')
        # Context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            # Without this, a 404/500 response body would be written to disk
            # as if it were the checkpoint.
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as file, tqdm(
                desc=save_path.name,
                total=total,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
        # Publish only once complete. The temp file is in the same directory,
        # so the rename stays on one filesystem.
        tmp_path.replace(save_path)
        print(f'{save_path.name} download finished')
    except Exception as e:
        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
        raise Exception(f"Download failed: {e}") from e
def get_obj_from_str(string):
    """Resolve a dotted ``"package.module.Name"`` string to the named object.

    The module part is imported as given first. If that import raises an
    ImportError, or the module has no such attribute, the lookup is retried
    under the ``src.`` package prefix. This lets specs such as the
    ``__class__`` entries read by load_obj name project objects with or
    without the ``src.`` prefix.

    Raises:
        ValueError: if ``string`` contains no dot.
        ImportError / AttributeError: if neither lookup succeeds.
        Any other exception raised while importing the module propagates
        unchanged instead of being masked by the fallback.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Only "not found"-style failures trigger the src. fallback. A bare
        # except also swallowed KeyboardInterrupt and genuine errors raised by
        # the module's own code, replacing them with a misleading
        # "No module named 'src....'".
        return getattr(importlib.import_module('src.' + module, package=None), cls)
def load_obj(path):
    """Instantiate the object described by a YAML spec file.

    The spec must contain a dotted ``__class__`` path, which is resolved with
    get_obj_from_str. It may contain an ``__init__`` mapping whose entries are
    passed to that callable as keyword arguments (none if the key is absent).
    """
    spec = OmegaConf.load(path)
    factory = get_obj_from_str(spec['__class__'])
    init_kwargs = spec.get("__init__", {})
    return factory(**init_kwargs)
def load_state_dict(model_path, device='cpu'):
    """Load a state dict from a checkpoint file, dispatching on its extension.

    Supported formats (extension matched case-insensitively):
        * ``.safetensors``: read with ``safetensors.torch.load_file``.
        * ``.ckpt``: ``torch.load``. If the result is a dict holding the
          weights under a ``"state_dict"`` key, that entry is returned.
          Otherwise the loaded object is returned as-is.
        * ``.bin``: ``torch.load``. The loaded object is returned as-is.

    Args:
        model_path: Path (str or PathLike) to the checkpoint file.
        device: Device the tensors are materialized on. The default, CPU,
            applies to every format and matches ``safetensors.torch.load_file``.
            Without it, ``torch.load`` restores tensors to whatever device they
            were saved from (e.g. a GPU).

    Returns:
        The loaded state dict (parameter name -> tensor).

    Raises:
        Exception: if the extension is not supported.
    """
    suffix = Path(model_path).suffix
    ext = suffix.lower()
    if ext == '.safetensors':
        return safetensors.torch.load_file(model_path, device=device)
    if ext in ('.ckpt', '.bin'):
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model_path, map_location=device)
        if ext == '.ckpt' and isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        return checkpoint
    raise Exception(f'Unsupported model extension {suffix}')
def load_sd_inpainting_model(
    download_url,
    model_path,
    sd_version,
    diffusers_ckpt=False,
    dtype=torch.float16,
    device='cuda:0'
):
    """Download (if needed) and assemble a Stable Diffusion inpainting DDIM.

    Two checkpoint layouts are accepted:

    * ``download_url`` and ``model_path`` are both strings: one combined
      checkpoint stored at ``MODEL_FOLDER/model_path``. Its keys are split by
      the ``model.diffusion_model``, ``cond_stage_model`` and
      ``first_stage_model`` prefixes into UNet, text-encoder and VAE weights.
      ``diffusers_ckpt=True`` is not supported for this layout.
    * Both are dicts (e.g. ``OrderedDict``) keyed by component: each entry
      of ``download_url`` is downloaded to ``MODEL_FOLDER/model_path[key]``,
      then the ``"unet"``, ``"encoder"`` and ``"vae"`` files are loaded. With
      ``diffusers_ckpt=True`` their keys are converted from the diffusers
      naming scheme.

    Args:
        download_url: URL string, or dict of component -> URL.
        model_path: Path relative to MODEL_FOLDER, or dict of component -> path.
        sd_version: 1 or 2. Selects the encoder config (clip vs openclip) and
            the inpainting UNet config (v1 vs v2).
        diffusers_ckpt: Whether the weights use diffusers key names.
        dtype: If ``torch.float16``, the UNet's ``convert_to_fp16()`` is
            called. The VAE and encoder are cast to ``dtype``.
        device: Device all three models are placed on.

    Returns:
        A ``DDIM`` built from the VAE, encoder and UNet. Side effect: sets
        ``share.schedule`` to a linear schedule from ``ddpm/v1.yaml``.

    Raises:
        NotImplementedError: single-file layout with ``diffusers_ckpt=True``.
        Exception: unsupported ``sd_version`` or argument types.
    """
    # Validate cheap arguments first, so a bad value fails before any
    # (potentially multi-gigabyte) checkpoint is downloaded or read.
    version_configs = {
        1: ('encoders/clip.yaml', 'unet/inpainting/v1.yaml'),
        2: ('encoders/openclip.yaml', 'unet/inpainting/v2.yaml'),
    }
    if sd_version not in version_configs:
        raise Exception(f'Unsupported SD version {sd_version}.')
    encoder_config, unet_config = version_configs[sd_version]

    unet_state, encoder_state, vae_state = _load_inpainting_states(
        download_url, model_path, diffusers_ckpt
    )

    # DDPM settings shared by all versions (timesteps, linear schedule endpoints).
    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    # Place models on `device` instead of a hard-coded .cuda(), which ignored
    # the argument (e.g. 'cuda:1') and required a GPU even for device='cpu'.
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().to(device)
    encoder = load_obj(f'{CONFIG_FOLDER}/{encoder_config}').eval().to(device)
    unet = load_obj(f'{CONFIG_FOLDER}/{unet_config}').eval().to(device)

    unet.load_state_dict(unet_state)
    # strict=False is kept as-is: the encoder weights are not required to match
    # the module key-for-key (TODO confirm which keys are expected to differ).
    encoder.load_state_dict(encoder_state, strict=False)
    vae.load_state_dict(vae_state)

    if dtype == torch.float16:
        unet.convert_to_fp16()
    # The UNet is moved without a dtype; its precision is handled by
    # convert_to_fp16 above (presumably mixed precision internally - confirm).
    unet.to(device=device)
    vae.to(dtype=dtype, device=device)
    encoder.to(dtype=dtype, device=device)
    # NOTE(review): the encoder appears to read its target device from this attribute.
    encoder.device = device

    # Inference only: freeze every parameter.
    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    # Build the sampler once, from the finished models.
    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
    return ddim


def _load_inpainting_states(download_url, model_path, diffusers_ckpt):
    """Download and load the (unet, encoder, vae) state dicts for either layout."""
    if isinstance(download_url, str) and isinstance(model_path, str):
        if diffusers_ckpt:
            # Checked before downloading so the caller is not made to wait for
            # a large file that cannot be used anyway.
            raise NotImplementedError('Not implemented')
        model_path = f'{MODEL_FOLDER}/{model_path}'
        download_file(download_url, model_path)
        state_dict = load_state_dict(model_path)
        return (
            _extract_prefixed(state_dict, 'model.diffusion_model'),
            _extract_prefixed(state_dict, 'cond_stage_model'),
            _extract_prefixed(state_dict, 'first_stage_model'),
        )
    if isinstance(download_url, dict) and isinstance(model_path, dict):
        for key, url in download_url.items():
            download_file(url, f'{MODEL_FOLDER}/{model_path[key]}')
        unet_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["unet"]}')
        encoder_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["encoder"]}')
        vae_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["vae"]}')
        if diffusers_ckpt:
            return _convert_diffusers_states(unet_state, encoder_state, vae_state)
        return unet_state, encoder_state, vae_state
    raise Exception('download_url or model_path definition type is not supported')


def _convert_diffusers_states(unet_state, encoder_state, vae_state):
    """Rename diffusers-format unet/encoder/vae weights to the keys the SD models expect."""
    unet_state = convert_unet_state_dict(unet_state)
    # The presence of text-encoder layer index 22 is the flag this code uses to
    # detect a v2.0 model, which needs the v20 converter and a "model." prefix.
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in encoder_state
    if is_v20_model:
        encoder_state = {"transformer." + k: v for k, v in encoder_state.items()}
        encoder_state = convert_text_enc_state_dict_v20(encoder_state)
        encoder_state = {"model." + k: v for k, v in encoder_state.items()}
    else:
        encoder_state = convert_text_enc_state_dict(encoder_state)
        encoder_state = {"transformer." + k: v for k, v in encoder_state.items()}
    vae_state = convert_vae_state_dict(vae_state)
    return unet_state, encoder_state, vae_state


def _extract_prefixed(state_dict, prefix):
    """Return the entries whose key starts with ``prefix + '.'``, with that prefix stripped."""
    # A true prefix test: the old substring test could match the prefix in the
    # middle of a key and then strip the wrong characters.
    head = prefix + '.'
    return {k[len(head):]: v for k, v in state_dict.items() if k.startswith(head)}