#!/usr/bin/env python
"""Gradio demo app for APDrawingGAN2 (https://github.com/yiranran/APDrawingGAN2)."""

from __future__ import annotations

import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
import uuid

import cv2

# The APDrawingGAN2 sources live in a sub-directory; it must be on sys.path
# before the project modules below (options, data, models, preprocess, util)
# are imported.
sys.path.insert(0, 'APDrawingGAN2')

import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
from io import BytesIO
import shutil

from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
import dlib
from preprocess.get_partmask import get_68lm, get_partmask
from util import html
import ntpath
from util import util

ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
TITLE = 'yiranran/APDrawingGAN2'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}. """
ARTICLE = """ """

MODEL_REPO = 'hylee/apdrawing_model'


def parse_args() -> argparse.Namespace:
    """Parse the command-line flags that configure the Gradio app."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def load_checkpoint():
    """Download and unpack the model checkpoints and the dataset archive.

    ``checkpoints.zip`` from MODEL_REPO is unpacked into ``checkpoint/``;
    ``dataset.zip`` is unpacked into the current working directory.

    Returns:
        The path ``'checkpoint/checkpoints'``, where the unpacked checkpoints
        (including the dlib landmark model used by main()) are expected.
    """
    extract_dir = 'checkpoint'  # avoid shadowing the builtin ``dir``
    checkpoint_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'checkpoints.zip', force_filename='checkpoints.zip')
    print(checkpoint_path)
    shutil.unpack_archive(checkpoint_path, extract_dir=extract_dir)
    print(os.listdir(extract_dir + '/checkpoints'))
    dataset_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'dataset.zip', force_filename='dataset.zip')
    # Previously this printed checkpoint_path a second time by mistake.
    print(dataset_path)
    shutil.unpack_archive(dataset_path)
    return extract_dir + '/checkpoints'


def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save every model visual as a PNG file in *image_dir*.

    Args:
        image_dir: directory the PNG files are written to.
        visuals: mapping of label -> image tensor, as returned by
            ``model.get_current_visuals()``.
        image_path: sequence of source image paths; only the first one is
            used, to derive the output base name.
        aspect_ratio: when > 1.0 the image is widened to ``w * aspect_ratio``;
            when < 1.0 it is made taller, ``h / aspect_ratio``; 1.0 leaves
            it unchanged.
        width: unused; kept for interface compatibility.

    Returns:
        List of written file paths (``<name>_<label>.png``), in the iteration
        order of *visuals*.
    """
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]
    imgs = []
    for label, im_data in visuals.items():
        # Tensor -> numpy image; presumably HxWxC uint8 in [0, 255]
        # (NOTE(review): confirm against util.tensor2im).
        im = util.tensor2im(im_data)
        save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))
        h, w, _ = im.shape
        # PIL sizes are (width, height). The old code referenced an
        # undefined ``arr`` and passed the size in (h, w) order.
        if aspect_ratio > 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (int(w * aspect_ratio), h), PIL.Image.BICUBIC))
        elif aspect_ratio < 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (w, int(h / aspect_ratio)), PIL.Image.BICUBIC))
        util.save_image(im, save_path)
        imgs.append(save_path)
    return imgs


# 64-symbol alphabet: each symbol encodes 6 bits.
SAFEHASH = list("0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ")


def compress_UUID():
    """Return a short random identifier derived from a UUID4.

    The first 30 hex digits of a random UUID4 are read three at a time
    (12 bits). Each 12-bit group becomes two SAFEHASH symbols (high 6 bits,
    then low 6 bits), giving 20 characters. Any '-' symbols are then removed,
    so the result may be shorter than 20 characters and contains only
    ``[0-9a-zA-Z_]``, which makes it usable as a directory name.

    :return: String
    """
    hex_digits = uuid.uuid4().hex
    symbols = []
    for start in range(0, 30, 3):
        group = int(hex_digits[start:start + 3], 16)  # 12-bit value
        symbols.append(SAFEHASH[group >> 6])
        symbols.append(SAFEHASH[group & 0x3F])
    return ''.join(symbols).replace('-', '')


def run(
    image,
    mask,
    model,
    opt,
    detector,
    predictor,
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
    """Run the model on one uploaded image and its background mask.

    Each call works in its own directory ``images/<compress_UUID()>/``. That
    directory receives the copied input image, a grayscale background mask,
    landmark text files and per-part masks; the data loader then reads them
    back through the directory fields set on *opt*.

    Args:
        image: uploaded image; an object whose ``.name`` is a file path on
            disk (Gradio ``type='file'`` input).
        mask: uploaded background mask, the same kind of object as *image*.
        model: model created by create_model() and set up in main().
        opt: TestOptions namespace; its directory fields are overwritten.
        detector: dlib frontal face detector.
        predictor: dlib 68-point shape predictor.

    Returns:
        The first two images written by save_images2(), opened with PIL.

    Raises:
        ValueError: if OpenCV cannot read the mask file.
        RuntimeError: if fewer than two result images were produced.
    """
    dataroot = 'images/' + compress_UUID()
    # NOTE(review): opt is a single object shared by every request (bound
    # via functools.partial in main) and is mutated here; concurrent calls
    # would race on these fields.
    opt.dataroot = os.path.join(dataroot, 'src/')
    opt.results_dir = os.path.join(dataroot, 'results/')
    opt.lm_dir = os.path.join(dataroot, 'landmark/')
    opt.cmask_dir = os.path.join(dataroot, 'mask/')
    opt.bg_dir = os.path.join(dataroot, 'mask/bg')
    for directory in (opt.dataroot, opt.results_dir, opt.lm_dir,
                      opt.cmask_dir, opt.bg_dir):
        os.makedirs(directory, exist_ok=True)

    shutil.copy(image.name, opt.dataroot)
    fullname = os.path.basename(image.name)
    # NOTE(review): this truncates at the FIRST dot ('a.b.png' -> 'a'); kept
    # as-is because the data loader's file-name matching is not visible here.
    name = fullname.split(".")[0]

    # Store the background mask as a single-channel PNG.
    mask_bgr = cv2.imread(mask.name)
    if mask_bgr is None:
        raise ValueError('could not read mask image: %s' % mask.name)
    bg = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)
    cv2.imwrite(os.path.join(opt.bg_dir, name + '.png'), bg)

    # Preprocessing: facial landmarks, then one mask per facial part.
    imgfile = os.path.join(opt.dataroot, fullname)
    lmfile5 = os.path.join(opt.lm_dir, name + '.txt')
    lmfile68 = os.path.join(opt.lm_dir, name + '_68.txt')
    get_68lm(imgfile, lmfile5, lmfile68, detector, predictor)
    for part in ['eyel', 'eyer', 'nose', 'mouth']:
        part_dir = opt.cmask_dir + part
        # The part directories were never created before; make sure they
        # exist so the mask can be written (harmless if get_partmask also
        # creates them).
        os.makedirs(part_dir, exist_ok=True)
        get_partmask(imgfile, part, lmfile68, os.path.join(part_dir, name + '.png'))

    # Upstream reference command (here opt comes from TestOptions().parse()
    # in main, not from these flags):
    #   python test.py --dataroot dataset/test_single --name apdrawinggan++_author
    #     --model test --use_resnet --netG resnet_9blocks --which_epoch 150
    #     --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    imgs = []
    for i, data in enumerate(dataset):
        # Test code only supports batch_size = 1; how_many is the number of
        # test images to run.
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()  # in test, loadSize == fineSize
        img_path = model.get_image_paths()
        imgs = save_images2(opt.results_dir, visuals, img_path,
                            aspect_ratio=opt.aspect_ratio,
                            width=opt.display_winsize)
    print(imgs)
    if len(imgs) < 2:
        raise RuntimeError(
            'expected at least two result images, got %d' % len(imgs))
    return PIL.Image.open(imgs[0]), PIL.Image.open(imgs[1])


def main():
    """Download the weights, build the model and launch the Gradio UI."""
    gr.close_all()

    args = parse_args()
    checkpoint_dir = load_checkpoint()

    opt = TestOptions().parse()
    opt.num_threads = 1  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    opt.checkpoints_dir = checkpoint_dir

    model = create_model(opt)
    model.setup(opt)

    # Preprocessing models used by run(): dlib face detector and 68-point
    # landmark predictor.
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(
        checkpoint_dir + '/shape_predictor_68_face_landmarks.dat')

    func = functools.partial(run,
                             model=model,
                             opt=opt,
                             detector=detector,
                             predictor=predictor)
    # Copy run's metadata (name, docstring) onto the partial; presumably so
    # Gradio can introspect it like a plain function.
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
            gr.inputs.Image(type='file', label='Input Mask'),
        ],
        [
            gr.outputs.Image(type='pil', label='Result'),
            gr.outputs.Image(type='pil', label='Result'),
        ],
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()