# File size: 13,562 Bytes
# d73173f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from .base_model import BaseModel
from . import networks
import torch


class TestModel(BaseModel):
    def name(self):
        """Return the human-readable identifier of this model class."""
        model_name = 'TestModel'
        return model_name

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Adapt the option parser for inference with TestModel.

        Installs test-time defaults (single-image dataset, local part
        generators, autoencoder refinement, compact masks, soft borders) and
        registers the TestModel-specific flags.

        Args:
            parser: the argument parser to modify in place.
            is_train: must be False; TestModel is inference-only.

        Returns:
            The same ``parser`` object.
        """
        assert not is_train, 'TestModel cannot be used in train mode'
        # NOTE: ``no_dropout`` is deliberately not overridden here (an earlier
        # version set ``no_dropout=True`` to mirror the default CycleGAN).
        test_defaults = dict(
            pool_size=0,
            no_lsgan=True,  # i.e. use_lsgan=False
            norm='batch',
            dataset_mode='single',
            auxiliary_root='auxiliaryeye2o',
            use_local=True,
            hair_local=True,
            bg_local=True,
            nose_ae=True,
            others_ae=True,
            compactmask=True,
            MOUTH_H=56,
            soft_border=1,
        )
        parser.set_defaults(**test_defaults)

        parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier')
        parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator')
        parser.add_argument('--model_suffix', type=str, default='',
                            help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will'
                                 ' be loaded as the generator of TestModel')
        return parser

    def initialize(self, opt):
        """Define every network TestModel uses at inference time.

        Always defines the global generator ``netG``. When ``opt.use_local``
        is set it also defines the local part generators (left/right eye,
        nose, mouth, hair, background), the combiner ``netGCombine``, the
        mouth/hair classifiers ``netCLm``/``netCLh`` and, depending on
        ``opt.nose_ae`` / ``opt.others_ae``, autoencoders whose parameters are
        frozen with ``set_requires_grad(..., False)``.

        The names collected in ``self.model_names`` and
        ``self.auxiliary_model_names`` are used by BaseModel's save/load code
        (per the comments below); weights are not loaded here.

        Args:
            opt: parsed options; ``opt.isTrain`` must be False.
        """
        assert(not opt.isTrain)
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = []
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        self.model_names = ['G' + opt.model_suffix]
        self.auxiliary_model_names = []
        if self.opt.use_local:
            self.model_names += ['GLEyel', 'GLEyer', 'GLNose', 'GLMouth', 'GLHair', 'GLBG', 'GCombine']
            self.auxiliary_model_names += ['CLm', 'CLh']
            # auxiliary nets for local output refinement
            if self.opt.nose_ae:
                self.auxiliary_model_names += ['AE']
            if self.opt.others_ae:
                self.auxiliary_model_names += ['AEel', 'AEer', 'AEmowhite', 'AEmoblack']
        print('model_names', self.model_names)
        print('auxiliary_model_names', self.auxiliary_model_names)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                      opt.nnG)
        print('netG', opt.netG)
        if not self.opt.use_local:
            # assigns the model to self.netG_[suffix] so that it can be loaded
            # please see BaseModel.load_networks
            setattr(self, 'netG' + opt.model_suffix, self.netG)
            return

        netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
        netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
        netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
        self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
        self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
        self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
        self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
        # the hair generator takes 3 extra channels (the hair-class one-hot fed in forward)
        self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
                                           extra_channel=3)
        self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
        # by default combiner_type is combiner, which uses resnet
        print('combiner_type', self.opt.combiner_type)
        # input is the global output and the combined local output stacked together
        self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm,
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2)

        # Part sizes in opt are scaled by fineSize / 256. Under Python 3 that
        # ratio is a float, so every scaled size is cast to int to hand the
        # networks integer spatial sizes (previously only the mouth size was).
        ratio = self.opt.fineSize / 256
        self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
        self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
        # auxiliary classifiers for mouth and hair
        self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                        nnG=3, ae_h=self.MOUTH_H, ae_w=self.MOUTH_W)
        self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                        nnG=opt.nnG_hairc, ae_h=opt.fineSize, ae_w=opt.fineSize)

        # ==================================auxiliary nets (loaded, parameters fixed)=============================
        if self.opt.nose_ae:
            nose_h = int(self.opt.NOSE_H * ratio)
            nose_w = int(self.opt.NOSE_W * ratio)
            self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                           latent_dim=self.opt.ae_latentno, ae_h=nose_h, ae_w=nose_w)
            self.set_requires_grad(self.netAE, False)
        if self.opt.others_ae:
            eye_h = int(self.opt.EYE_H * ratio)
            eye_w = int(self.opt.EYE_W * ratio)
            self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=eye_h, ae_w=eye_w)
            self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=eye_h, ae_w=eye_w)
            # The mouth AEs reuse self.MOUTH_H / self.MOUTH_W so their output
            # size matches the per-sample buffer that forward() allocates.
            self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=self.MOUTH_H, ae_w=self.MOUTH_W)
            self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=self.MOUTH_H, ae_w=self.MOUTH_W)
            for frozen_net in (self.netAEel, self.netAEer, self.netAEmowhite, self.netAEmoblack):
                self.set_requires_grad(frozen_net, False)

        # assigns the model to self.netG_[suffix] so that it can be loaded
        # please see BaseModel.load_networks
        setattr(self, 'netG' + opt.model_suffix, self.netG)

    def set_input(self, input):
        """Unpack one batch (single-dataset mode) onto the model.

        Always reads ``'A'`` (moved to ``self.device``) and ``'A_paths'``, and
        sets ``self.batch_size`` to the number of paths. When
        ``opt.use_local`` is set it also reads the per-part crops, ``'center'``
        (kept as delivered, not moved to the device), the hair/background
        crops and masks, plus the soft-border masks when ``opt.soft_border``
        is set and the compact masks when ``opt.compactmask`` is set.

        Args:
            input: dict-like batch from the data loader.
        """
        device = self.device
        self.real_A = input['A'].to(device)
        self.image_paths = input['A_paths']
        self.batch_size = len(self.image_paths)
        if not self.opt.use_local:
            return

        part_crops = (
            ('real_A_eyel', 'eyel_A'),
            ('real_A_eyer', 'eyer_A'),
            ('real_A_nose', 'nose_A'),
            ('real_A_mouth', 'mouth_A'),
        )
        for attr, key in part_crops:
            setattr(self, attr, input[key].to(device))
        self.center = input['center']

        if self.opt.soft_border:
            soft_masks = (
                ('softel', 'soft_eyel_mask'),
                ('softer', 'soft_eyer_mask'),
                ('softno', 'soft_nose_mask'),
                ('softmo', 'soft_mouth_mask'),
            )
            for attr, key in soft_masks:
                setattr(self, attr, input[key].to(device))

        if self.opt.compactmask:
            # each compact mask is kept as loaded and rescaled [0,1] -> [-1,1]
            compact_masks = (
                ('cmask', 'cmask1'),
                ('cmaskel', 'cmask1el'),
                ('cmasker', 'cmask1er'),
                ('cmaskmo', 'cmask1mo'),
            )
            for key, rescaled_attr in compact_masks:
                loaded = input[key].to(device)
                setattr(self, key, loaded)
                setattr(self, rescaled_attr, loaded * 2 - 1)

        self.real_A_hair = input['hair_A'].to(device)
        self.mask = input['mask'].to(device)    # mask for non-eyes,nose,mouth
        self.mask2 = input['mask2'].to(device)  # mask for non-bg
        self.real_A_bg = input['bg_A'].to(device)

    def getonehot(self, outputs, classes):
        """Turn classifier scores into a one-hot float tensor.

        Args:
            outputs: 2-D score tensor; the arg-max is taken over dim 1, so
                row ``i`` holds the scores of sample ``i``.
            classes: number of columns of the result; must exceed every
                arg-max index.

        Returns:
            float32 tensor of shape ``(outputs.size(0), classes)`` on
            ``self.device`` containing a single 1 per row, at the arg-max.
        """
        _, index = torch.max(outputs, 1)
        # Size the result from the scores themselves instead of the cached
        # self.batch_size, and allocate it zero-filled directly on the target
        # device (no uninitialized CPU tensor + copy).
        onehot = torch.zeros(outputs.size(0), classes, dtype=torch.float32, device=self.device)
        onehot.scatter_(1, index.unsqueeze(1), 1)
        return onehot

    def forward(self):
        """Run the test-time generator; the final image is stored in ``self.fake_B``.

        Without ``opt.use_local`` this is just ``netG(real_A)``. Otherwise:
          * ``self.fake_B0`` is the global generator output ``netG(real_A)``;
          * the local generators produce eye, mouth, nose, hair and background
            outputs; eyes/mouth are additionally passed through the frozen
            autoencoders when ``opt.others_ae`` is set, the nose when
            ``opt.nose_ae`` is set;
          * ``self.partCombiner2_bg`` merges the parts into ``self.fake_B1``;
          * ``self.fake_B = netGCombine(cat([fake_B0, fake_B1], 1))``.

        Reads the tensors stored by ``set_input``. The autoencoder branches
        also read the compact masks (``cmask``, ``cmaskel``, ``cmasker``,
        ``cmaskmo``), which ``set_input`` only sets when ``opt.compactmask``
        is on.
        """
        if not self.opt.use_local:
            self.fake_B = self.netG(self.real_A)
        else:
            self.fake_B0 = self.netG(self.real_A)
            # EYES, MOUTH
            # Mouth classifier -> one-hot over 2 classes. Below, column 0 selects
            # netAEmowhite and column 1 selects netAEmoblack for each sample.
            outputs1 = self.netCLm(self.real_A_mouth)
            onehot1 = self.getonehot(outputs1,2)

            if not self.opt.others_ae:
                fake_B_eyel = self.netGLEyel(self.real_A_eyel)
                fake_B_eyer = self.netGLEyer(self.real_A_eyer)
                fake_B_mouth = self.netGLMouth(self.real_A_mouth)
            else: # use AE that only contains compact region, need cmask!
                self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
                self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
                self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
                # the AEs return a pair; only the first element is used here
                self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1)
                self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1)
                # USE 2 AEs
                # Buffer is allocated uninitialized; row i is written only when
                # onehot1[i] has its 1 in column 0 or 1 (presumably always, since
                # getonehot places exactly one 1 per row).
                self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device)
                for i in range(self.batch_size):
                    if onehot1[i][0] == 1:
                        self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
                        #print('AEmowhite')
                    elif onehot1[i][1] == 1:
                        self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
                        #print('AEmoblack')
                # merge AE output with the generator output using the compact
                # masks (exact blending is defined by add_with_mask, not shown here)
                fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel)
                fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker)
                fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo)
            # NOSE
            if not self.opt.nose_ae:
                fake_B_nose = self.netGLNose(self.real_A_nose)
            else: # use AE that only contains compact region, need cmask!
                self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
                self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1)
                fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask)
            
            # HAIR, BG AND PARTCOMBINE
            # Hair classifier -> one-hot over 3 classes, passed to netGLHair as
            # its second input.
            outputs2 = self.netCLh(self.real_A_hair)
            onehot2 = self.getonehot(outputs2,3)

            fake_B_hair = self.netGLHair(self.real_A_hair,onehot2)
            fake_B_bg = self.netGLBG(self.real_A_bg)
            # Masked copies are stored on self; note the combiner below receives
            # the unmasked local tensors together with the masks.
            # mask = non-eyes/nose/mouth, mask2 = non-bg (see set_input), so
            # mask*mask2 presumably selects the hair region and
            # inverse_mask(mask2) the background -- TODO confirm.
            self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2)
            self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2))
            if not self.opt.compactmask:
                self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op)
            else:
                self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo)
            
            # final output: combiner over the global and part-combined outputs
            # concatenated along dim 1 (netGCombine is built with 2*output_nc inputs)
            self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1))