hylee commited on
Commit
0776c7c
1 Parent(s): f306a44
Files changed (2) hide show
  1. app.py +20 -16
  2. requirements.txt +3 -1
app.py CHANGED
@@ -29,6 +29,7 @@ from util import html
29
 
30
  import ntpath
31
  from util import util
 
32
 
33
  ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
34
  TITLE = 'yiranran/APDrawingGAN2'
@@ -64,17 +65,15 @@ def load_checkpoint():
64
  force_filename='checkpoints.zip')
65
  print(checkpoint_path)
66
  shutil.unpack_archive(checkpoint_path, extract_dir=dir)
67
-
68
  print(os.listdir(dir + '/checkpoints'))
69
-
70
- dataset_path = huggingface_hub.hf_hub_download(MODEL_REPO,
71
- 'dataset.zip',
72
- force_filename='dataset.zip')
73
- print(checkpoint_path)
74
- shutil.unpack_archive(dataset_path)
75
-
76
  return dir + '/checkpoints'
77
 
 
 
 
 
 
 
78
 
79
  # save image to the disk
80
  def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
@@ -124,6 +123,7 @@ def run(
124
  opt,
125
  detector,
126
  predictor,
 
127
  ) -> tuple[PIL.Image.Image,PIL.Image.Image]:
128
 
129
  dataroot = 'images/' + compress_UUID()
@@ -144,9 +144,9 @@ def run(
144
  fullname = os.path.basename(image.name)
145
  name = fullname.split(".")[0]
146
 
147
- bg = cv2.cvtColor(cv2.imread(mask.name), cv2.COLOR_BGR2GRAY)
148
- cv2.imwrite(os.path.join(opt.bg_dir, name+'.png'), bg)
149
- #shutil.copyfile(mask.name, os.path.join(opt.bg_dir, name+'.png'))
150
 
151
  imgfile = os.path.join(opt.dataroot, fullname)
152
  lmfile5 = os.path.join(opt.lm_dir, name+'.txt')
@@ -155,10 +155,10 @@ def run(
155
  get_68lm(imgfile, lmfile5, lmfile68, detector, predictor)
156
 
157
  imgs = []
158
- # for part in ['eyel', 'eyer', 'nose', 'mouth']:
159
- # savepath = os.path.join(opt.cmask_dir + part, name+'.png')
160
- # get_partmask(imgfile, part, lmfile68, savepath)
161
- # #imgs.append(savepath)
162
 
163
  data_loader = CreateDataLoader(opt)
164
  dataset = data_loader.load_data()
@@ -201,13 +201,17 @@ def main():
201
  model = create_model(opt)
202
  model.setup(opt)
203
 
 
 
 
 
204
  '''
205
  预处理数据
206
  '''
207
  detector = dlib.get_frontal_face_detector()
208
  predictor = dlib.shape_predictor(checkpoint_dir + '/shape_predictor_68_face_landmarks.dat')
209
 
210
- func = functools.partial(run, model=model, opt=opt, detector=detector, predictor=predictor)
211
  func = functools.update_wrapper(func, run)
212
 
213
  gr.Interface(
 
29
 
30
  import ntpath
31
  from util import util
32
+ from modnet import ModNet
33
 
34
  ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
35
  TITLE = 'yiranran/APDrawingGAN2'
 
65
  force_filename='checkpoints.zip')
66
  print(checkpoint_path)
67
  shutil.unpack_archive(checkpoint_path, extract_dir=dir)
 
68
  print(os.listdir(dir + '/checkpoints'))
 
 
 
 
 
 
 
69
  return dir + '/checkpoints'
70
 
71
+ def load_modnet_model():
72
+ modnet_path = huggingface_hub.hf_hub_download(MODEL_REPO,
73
+ 'modnet.onnx',
74
+ force_filename='modnet.onnx')
75
+ return modnet_path
76
+
77
 
78
  # save image to the disk
79
  def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
 
123
  opt,
124
  detector,
125
  predictor,
126
+ modnet : ModNet,
127
  ) -> tuple[PIL.Image.Image,PIL.Image.Image]:
128
 
129
  dataroot = 'images/' + compress_UUID()
 
144
  fullname = os.path.basename(image.name)
145
  name = fullname.split(".")[0]
146
 
147
+ #bg = cv2.cvtColor(cv2.imread(mask.name), cv2.COLOR_BGR2GRAY)
148
+ #cv2.imwrite(os.path.join(opt.bg_dir, name+'.png'), bg)
149
+ modnet.segment(image.name, os.path.join(opt.bg_dir, name+'.png'))
150
 
151
  imgfile = os.path.join(opt.dataroot, fullname)
152
  lmfile5 = os.path.join(opt.lm_dir, name+'.txt')
 
155
  get_68lm(imgfile, lmfile5, lmfile68, detector, predictor)
156
 
157
  imgs = []
158
+ for part in ['eyel', 'eyer', 'nose', 'mouth']:
159
+ savepath = os.path.join(opt.cmask_dir + part, name+'.png')
160
+ get_partmask(imgfile, part, lmfile68, savepath)
161
+ #imgs.append(savepath)
162
 
163
  data_loader = CreateDataLoader(opt)
164
  dataset = data_loader.load_data()
 
201
  model = create_model(opt)
202
  model.setup(opt)
203
 
204
+ modnet_path = load_modnet_model();
205
+ modnet = ModNet(modnet_path)
206
+
207
+
208
  '''
209
  预处理数据
210
  '''
211
  detector = dlib.get_frontal_face_detector()
212
  predictor = dlib.shape_predictor(checkpoint_dir + '/shape_predictor_68_face_landmarks.dat')
213
 
214
+ func = functools.partial(run, model=model, opt=opt, detector=detector, predictor=predictor, modnet=modnet)
215
  func = functools.update_wrapper(func, run)
216
 
217
  gr.Interface(
requirements.txt CHANGED
@@ -7,4 +7,6 @@ numpy==1.16.4
7
  pillow<7.0.0
8
  opencv-python==4.1.0.25
9
  dlib==19.18.0
10
- shapely==1.7.0
 
 
 
7
  pillow<7.0.0
8
  opencv-python==4.1.0.25
9
  dlib==19.18.0
10
+ shapely==1.7.0
11
+ onnx==1.8.1
12
+ onnxruntime==1.6.0