#!/usr/bin/env python
from __future__ import annotations

import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
import uuid

import cv2

sys.path.insert(0, 'APDrawingGAN2')

import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
from io import BytesIO
import shutil

from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
import dlib
from preprocess.get_partmask import get_68lm, get_partmask
from util import html
import ntpath
from util import util

from modnet import ModNet

ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
TITLE = 'yiranran/APDrawingGAN2'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
"""
ARTICLE = """
"""

MODEL_REPO = 'hylee/apdrawing_model'


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def load_checkpoint():
    # Download and unpack the pretrained APDrawingGAN2 checkpoints from the Hugging Face Hub.
    extract_dir = 'checkpoint'
    checkpoint_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'checkpoints.zip', force_filename='checkpoints.zip')
    print(checkpoint_path)
    shutil.unpack_archive(checkpoint_path, extract_dir=extract_dir)
    print(os.listdir(extract_dir + '/checkpoints'))

    return extract_dir + '/checkpoints'


def load_modnet_model():
    # Download the MODNet ONNX model used for background segmentation.
    modnet_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'modnet.onnx', force_filename='modnet.onnx')
    return modnet_path


# Save the visuals produced by the model to disk and return the saved paths.
def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    imgs = []
    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)  # tensor to numpy array, [-1,1]->[0,1]->[0,255]
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        if aspect_ratio > 1.0:
            # PIL's resize() takes (width, height); stretch the width by the aspect ratio.
            im = np.array(PIL.Image.fromarray(im).resize((int(w * aspect_ratio), h)))
        if aspect_ratio < 1.0:
            # Stretch the height by the inverse aspect ratio.
            im = np.array(PIL.Image.fromarray(im).resize((w, int(h / aspect_ratio))))
        util.save_image(im, save_path)
        imgs.append(save_path)
    return imgs


SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]


def compress_UUID():
    '''
    Generate a short identifier from a UUID by re-encoding it over a larger
    character set (cf. http://www.ietf.org/rfc/rfc1738.txt).
    Character set: [0-9a-zA-Z-_], 64 symbols in total.
    Length: (32-2)/3*2 = 20 characters.
    Note: with 2^120 possible values, collisions are practically impossible.
    :return: str
    '''
    row = str(uuid.uuid4()).replace('-', '')
    safe_code = ''
    for i in range(10):
        # Take 3 hex digits (12 bits) and map them to two 6-bit SAFEHASH indices.
        enbin = "%012d" % int(bin(int(row[i * 3] + row[i * 3 + 1] + row[i * 3 + 2], 16))[2:], 10)
        safe_code += (SAFEHASH[int(enbin[0:6], 2)] + SAFEHASH[int(enbin[6:12], 2)])
    safe_code = safe_code.replace('-', '')
    return safe_code


def run(
    image,
    landmarks_str: str,
    model,
    opt,
    detector,
    predictor,
    modnet: ModNet,
) -> PIL.Image.Image:
    # Create a unique working directory tree for this request.
    dataroot = 'images/' + compress_UUID()
    opt.dataroot = os.path.join(dataroot, 'src/')
    os.makedirs(opt.dataroot, exist_ok=True)
    opt.results_dir = os.path.join(dataroot, 'results/')
    os.makedirs(opt.results_dir, exist_ok=True)
    opt.lm_dir = os.path.join(dataroot, 'landmark/')
    opt.cmask_dir = os.path.join(dataroot, 'mask/')
    opt.bg_dir = os.path.join(dataroot, 'mask/bg')
    os.makedirs(opt.lm_dir, exist_ok=True)
    os.makedirs(opt.cmask_dir, exist_ok=True)
    os.makedirs(opt.bg_dir, exist_ok=True)

    shutil.copy(image.name, opt.dataroot)

    fullname = os.path.basename(image.name)
    name = fullname.split(".")[0]

    # MODNet generates the background mask (the commented-out code used a
    # manually uploaded mask instead).
    # bg = cv2.cvtColor(cv2.imread(mask.name), cv2.COLOR_BGR2GRAY)
    # cv2.imwrite(os.path.join(opt.bg_dir, name + '.png'), bg)
    modnet.segment(image.name, os.path.join(opt.bg_dir, name + '.png'))

    imgfile = os.path.join(opt.dataroot, fullname)
    lmfile5 = os.path.join(opt.lm_dir, name + '.txt')
    lmfile68 = os.path.join(opt.lm_dir, name + '_68.txt')

    # Preprocess the data: write the 68-point landmarks supplied through the
    # textbox to disk (instead of detecting them here).
    # get_68lm(imgfile, lmfile5, lmfile68, detector, predictor)
    with open(lmfile68, 'w') as f:
        print(landmarks_str, file=f)

    # Derive the 5-point landmark file (eye centers, nose tip, mouth corners)
    # from the 68-point landmarks.
    landmarks = np.loadtxt(lmfile68, dtype=np.int64)
    with open(lmfile5, 'w') as ff:
        lm = (landmarks[36] + landmarks[39]) / 2  # left eye center
        print(int(lm[0]), int(lm[1]), file=ff)
        lm = (landmarks[45] + landmarks[42]) / 2  # right eye center
        print(int(lm[0]), int(lm[1]), file=ff)
        lm = landmarks[30]  # nose tip
        print(lm[0], lm[1], file=ff)
        lm = landmarks[48]  # left mouth corner
        print(lm[0], lm[1], file=ff)
        lm = landmarks[54]  # right mouth corner
        print(lm[0], lm[1], file=ff)

    # Generate the part masks (eyes, nose, mouth) used by the local networks.
    imgs = []
    for part in ['eyel', 'eyer', 'nose', 'mouth']:
        savepath = os.path.join(opt.cmask_dir + part, name + '.png')
        get_partmask(imgfile, part, lmfile68, savepath)
        # imgs.append(savepath)

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    # test
    # model.eval()
    for i, data in enumerate(dataset):
        # The test code only supports batch_size = 1; how_many is the number of
        # test images to run.
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()  # in test, loadSize is set to the same as fineSize
        img_path = model.get_image_paths()
        # if i % 5 == 0:
        #     print('processing (%04d)-th image... %s' % (i, img_path))
        imgs = save_images2(opt.results_dir,
                            visuals,
                            img_path,
                            aspect_ratio=opt.aspect_ratio,
                            width=opt.display_winsize)

    print(imgs)

    return PIL.Image.open(imgs[1])


def main():
    gr.close_all()

    args = parse_args()

    checkpoint_dir = load_checkpoint()

    opt = TestOptions().parse()
    opt.num_threads = 1        # test code only supports num_threads = 1
    opt.batch_size = 1         # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True         # no flip
    opt.display_id = -1        # no visdom display
    opt.checkpoints_dir = checkpoint_dir
    model = create_model(opt)
    model.setup(opt)

    modnet_path = load_modnet_model()
    modnet = ModNet(modnet_path)

    # Preprocessing models: dlib face detector and 68-point landmark predictor.
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(
        checkpoint_dir + '/shape_predictor_68_face_landmarks.dat')

    func = functools.partial(run,
                             model=model,
                             opt=opt,
                             detector=detector,
                             predictor=predictor,
                             modnet=modnet)
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
            gr.inputs.Textbox(lines=1, label="Landmarks"),
        ],
        [
            gr.outputs.Image(type='pil', label='Result'),
        ],
        # examples=examples,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
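
# ----------------------------------------------------------------------------
# Usage sketch (assumptions, not taken from the original script):
#
#   # assuming this file is saved as app.py
#   python app.py --device cpu --port 7860
#
# The "Landmarks" textbox is expected to hold the 68 facial landmarks as
# whitespace-separated integer "x y" pairs, one pair per line, because run()
# writes the string verbatim to <name>_68.txt and reads it back with
# np.loadtxt(..., dtype=np.int64) before indexing points 30, 36, 39, 42, 45,
# 48 and 54. A (hypothetical) way to build such a string from a 68x2 array
# `points`:
#
#   landmarks_str = '\n'.join('%d %d' % (x, y) for x, y in points)
# ----------------------------------------------------------------------------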