hylee commited on
Commit
e12fee9
1 Parent(s): 2d16fc9
Files changed (1) hide show
  1. modnet.py +79 -0
modnet.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import argparse
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import onnx
8
+ import onnxruntime
9
+
10
+
11
+ class ModNet:
12
+
13
+ def __init__(self, model_path):
14
+ # Initialize session and get prediction
15
+ self.session = onnxruntime.InferenceSession(model_path, None)
16
+
17
+ # Get x_scale_factor & y_scale_factor to resize image
18
+ def get_scale_factor(im_h, im_w, ref_size):
19
+
20
+ if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
21
+ if im_w >= im_h:
22
+ im_rh = ref_size
23
+ im_rw = int(im_w / im_h * ref_size)
24
+ elif im_w < im_h:
25
+ im_rw = ref_size
26
+ im_rh = int(im_h / im_w * ref_size)
27
+ else:
28
+ im_rh = im_h
29
+ im_rw = im_w
30
+
31
+ im_rw = im_rw - im_rw % 32
32
+ im_rh = im_rh - im_rh % 32
33
+
34
+ x_scale_factor = im_rw / im_w
35
+ y_scale_factor = im_rh / im_h
36
+
37
+ return x_scale_factor, y_scale_factor
38
+
39
+ def segment(self, image_path, output_path):
40
+ ref_size = 512
41
+ ##############################################
42
+ # Main Inference part
43
+ ##############################################
44
+
45
+ # read image
46
+ im = cv2.imread(image_path)
47
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
48
+
49
+ # unify image channels to 3
50
+ if len(im.shape) == 2:
51
+ im = im[:, :, None]
52
+ if im.shape[2] == 1:
53
+ im = np.repeat(im, 3, axis=2)
54
+ elif im.shape[2] == 4:
55
+ im = im[:, :, 0:3]
56
+
57
+ # normalize values to scale it between -1 to 1
58
+ im = (im - 127.5) / 127.5
59
+
60
+ im_h, im_w, im_c = im.shape
61
+ x, y = get_scale_factor(im_h, im_w, ref_size)
62
+
63
+ # resize image
64
+ im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)
65
+
66
+ # prepare input shape
67
+ im = np.transpose(im)
68
+ im = np.swapaxes(im, 1, 2)
69
+ im = np.expand_dims(im, axis=0).astype('float32')
70
+
71
+ input_name = self.session.get_inputs()[0].name
72
+ output_name = self.session.get_outputs()[0].name
73
+ result = self.session.run([output_name], {input_name: im})
74
+
75
+ # refine matte
76
+ matte = (np.squeeze(result[0]) * 255).astype('uint8')
77
+ matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)
78
+
79
+ cv2.imwrite(output_path, matte)