# apdrawing / app.py
# Author: hylee · commit 3caa29f ("init") · 8.25 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
import uuid
import cv2
sys.path.insert(0, 'APDrawingGAN2')
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
from io import BytesIO
import shutil
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
import dlib
from preprocess.get_partmask import get_68lm, get_partmask
from util import html
import ntpath
from util import util
# Upstream project this demo wraps, and the Hub repo hosting its weights/data.
ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
MODEL_REPO = 'hylee/apdrawing_model'

# Text rendered by the Gradio interface.
TITLE = 'yiranran/APDrawingGAN2'
DESCRIPTION = f'This is a demo for {ORIGINAL_REPO_URL}.\n'
ARTICLE = '\n'
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line flags from ``sys.argv``.

    Returns:
        Namespace with ``device`` (default ``'cpu'``), ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue`` (cleared by ``--disable-queue``),
        ``allow_flagging`` (default ``'never'``) and ``allow_screenshot``.
    """
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help text is unchanged.
    flag_specs = (
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--theme', {'type': str}),
        ('--live', {'action': 'store_true'}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        ('--disable-queue', {'dest': 'enable_queue', 'action': 'store_false'}),
        ('--allow-flagging', {'type': str, 'default': 'never'}),
        ('--allow-screenshot', {'action': 'store_true'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_checkpoint():
    """Download and unpack the model checkpoints and dataset from the Hub.

    Fetches ``checkpoints.zip`` and ``dataset.zip`` from ``MODEL_REPO``.
    The checkpoints archive is extracted into ``checkpoint/``; the dataset
    archive is extracted into the current working directory.

    Returns:
        The extracted checkpoints directory, ``'checkpoint/checkpoints'``.
    """
    extract_dir = 'checkpoint'  # renamed from `dir`, which shadowed the builtin
    checkpoint_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'checkpoints.zip', force_filename='checkpoints.zip')
    print(checkpoint_path)
    shutil.unpack_archive(checkpoint_path, extract_dir=extract_dir)
    checkpoints_dir = extract_dir + '/checkpoints'
    print(os.listdir(checkpoints_dir))

    dataset_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'dataset.zip', force_filename='dataset.zip')
    # Log the dataset archive path (this previously re-printed
    # checkpoint_path because of a copy-paste slip).
    print(dataset_path)
    # No extract_dir: unpacked relative to the CWD, presumably where the
    # rest of the code expects the dataset — confirm before changing.
    shutil.unpack_archive(dataset_path)
    return checkpoints_dir
# save image to the disk
def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save every visual produced by the model as a PNG and return the paths.

    Args:
        image_dir: Directory the PNG files are written into.
        visuals: Mapping of label -> image tensor; each value is converted
            with ``util.tensor2im`` and must yield an (H, W, C) array
            (the code unpacks three dimensions).
        image_path: Sequence whose first element is the source image path;
            its base name without extension prefixes every output file name
            (``<name>_<label>.png``).
        aspect_ratio: When > 1.0 the image is widened to ``w * aspect_ratio``;
            when < 1.0 it is heightened to ``h / aspect_ratio``. 1.0 leaves
            the image unchanged.
        width: Unused; kept for signature compatibility with callers.

    Returns:
        List of written file paths, in ``visuals`` iteration order.
    """
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]
    saved_paths = []
    for label, im_data in visuals.items():
        # Per the original comment: tensor [-1, 1] -> uint8 array [0, 255].
        im = util.tensor2im(im_data)
        save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))
        h, w, _ = im.shape
        # PIL's resize takes (width, height). The previous code referenced an
        # undefined name `arr` and passed (h, w) in the wrong argument slot,
        # so any aspect_ratio != 1.0 raised NameError.
        if aspect_ratio > 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (int(w * aspect_ratio), h), PIL.Image.BICUBIC))
        elif aspect_ratio < 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (w, int(h / aspect_ratio)), PIL.Image.BICUBIC))
        util.save_image(im, save_path)
        saved_paths.append(save_path)
    return saved_paths
# 64-symbol URL-safe alphabet (RFC 1738): digits, '-', lowercase, '_', uppercase.
SAFEHASH = list("0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ")


def compress_UUID():
    """Return a short random identifier derived from a fresh UUID4.

    The first 30 hex digits (120 bits) of the UUID are read three at a time.
    Each 12-bit group is split into two 6-bit indexes into ``SAFEHASH``,
    giving 20 characters. Every '-' is then removed, so the result has at
    most 20 characters drawn from [0-9a-zA-Z_].

    Note: the original comment claimed a 2^120 collision space. That is an
    upper bound, because UUID4 fixes its version/variant bits and dashes are
    dropped.

    :return: String
    """
    hex_digits = uuid.uuid4().hex
    chars = []
    for start in range(0, 30, 3):
        # Zero-padded 12-bit binary string of one three-hex-digit group.
        bits = format(int(hex_digits[start:start + 3], 16), '012b')
        chars.append(SAFEHASH[int(bits[:6], 2)])
        chars.append(SAFEHASH[int(bits[6:], 2)])
    return ''.join(chars).replace('-', '')
def _prepare_workdir(opt):
    """Create a fresh per-request directory tree and point ``opt`` at it."""
    dataroot = 'images/' + compress_UUID()
    opt.dataroot = os.path.join(dataroot, 'src/')
    opt.results_dir = os.path.join(dataroot, 'results/')
    opt.lm_dir = os.path.join(dataroot, 'landmark/')
    # Trailing slash matters: part-mask dirs are built as cmask_dir + part.
    opt.cmask_dir = os.path.join(dataroot, 'mask/')
    opt.bg_dir = os.path.join(dataroot, 'mask/bg')
    for directory in (opt.dataroot, opt.results_dir, opt.lm_dir,
                      opt.cmask_dir, opt.bg_dir):
        os.makedirs(directory, exist_ok=True)


def _load_image(path):
    """Read an image fully into memory so its file handle is closed."""
    # PIL.Image.open is lazy and would otherwise keep the file open.
    with PIL.Image.open(path) as im:
        im.load()
        return im.copy()


def run(
    image,
    mask,
    model,
    opt,
    detector,
    predictor,
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
    """Run the model on one uploaded portrait and its background mask.

    The inputs are staged in a new directory under ``images/``. Facial
    landmarks and per-part masks are computed, then the model is run over
    that directory and its outputs are saved with ``save_images2``.

    Args:
        image: Uploaded image file; only ``image.name`` (a filesystem path)
            is used.
        mask: Uploaded background-mask file; read from ``mask.name`` with
            OpenCV and stored as a grayscale PNG.
        model: Model from ``create_model`` on which ``setup`` was called.
        opt: Test options. Mutated in place: the data/result/landmark/mask
            directory attributes are pointed at this request's directory.
        detector: dlib frontal face detector, passed to ``get_68lm``.
        predictor: dlib 68-point shape predictor, passed to ``get_68lm``.

    Returns:
        The first two images written by ``save_images2``, in the iteration
        order of ``model.get_current_visuals()``.

    Raises:
        ValueError: If the image or mask is missing, or the mask is unreadable.
        RuntimeError: If the model produced fewer than two images.
    """
    if image is None or mask is None:
        raise ValueError('Both an input image and a mask are required.')

    # NOTE(review): the work directory is never deleted, so disk usage grows
    # by one directory per request.
    _prepare_workdir(opt)

    shutil.copy(image.name, opt.dataroot)
    fullname = os.path.basename(image.name)
    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), matching how
    # save_images2 names files. The old split('.')[0] truncated it to "a".
    # Presumably the data loader derives landmark/mask file names from the
    # image stem — TODO confirm.
    name = os.path.splitext(fullname)[0]

    mask_bgr = cv2.imread(mask.name)
    if mask_bgr is None:  # cv2.imread signals failure by returning None
        raise ValueError('Could not read mask image: %s' % mask.name)
    bg = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)
    cv2.imwrite(os.path.join(opt.bg_dir, name + '.png'), bg)

    # Preprocessing: landmarks (5- and 68-point files) and facial-part masks.
    imgfile = os.path.join(opt.dataroot, fullname)
    lmfile5 = os.path.join(opt.lm_dir, name + '.txt')
    lmfile68 = os.path.join(opt.lm_dir, name + '_68.txt')
    get_68lm(imgfile, lmfile5, lmfile68, detector, predictor)
    for part in ('eyel', 'eyer', 'nose', 'mouth'):
        savepath = os.path.join(opt.cmask_dir + part, name + '.png')
        get_partmask(imgfile, part, lmfile68, savepath)

    # Inference; roughly the in-process equivalent of upstream's
    # `python test.py --dataroot <dir> --model test ...`.
    dataset = CreateDataLoader(opt).load_data()
    imgs = []
    for i, data in enumerate(dataset):
        # Test code only supports batch_size = 1; how_many caps the number
        # of test images processed.
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        imgs = save_images2(opt.results_dir, visuals, img_path,
                            aspect_ratio=opt.aspect_ratio,
                            width=opt.display_winsize)
    print(imgs)
    if len(imgs) < 2:
        raise RuntimeError(
            'Model produced %d output image(s); expected at least 2.' % len(imgs))
    return _load_image(imgs[0]), _load_image(imgs[1])
def main():
    """Load the model and landmark predictor, then launch the Gradio demo."""
    gr.close_all()
    args = parse_args()
    checkpoint_dir = load_checkpoint()

    # NOTE(review): parse_args() above rejects flags it does not know, and
    # TestOptions().parse() also reads sys.argv. Flags meant for one parser
    # are therefore likely rejected by the other — confirm before passing
    # CLI flags.
    opt = TestOptions().parse()
    opt.num_threads = 1  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    opt.checkpoints_dir = checkpoint_dir
    model = create_model(opt)
    model.setup(opt)

    # Face detector and 68-point landmark predictor used to preprocess inputs.
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(
        os.path.join(checkpoint_dir, 'shape_predictor_68_face_landmarks.dat'))

    func = functools.partial(run, model=model, opt=opt,
                             detector=detector, predictor=predictor)
    func = functools.update_wrapper(func, run)

    # run() returns exactly two images, so exactly two outputs are declared.
    # Declaring four made Gradio fail when postprocessing the 2-tuple.
    outputs = [gr.outputs.Image(type='pil', label='Result') for _ in range(2)]

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
            gr.inputs.Image(type='file', label='Input Mask'),
        ],
        outputs,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()