# Page residue from the hosting site (author: hylee; commit 5c971ea "init";
# links: raw | history | blame; scan: "No virus"; size: 28.7 kB).
import os
import torch
from collections import OrderedDict
from . import networks
class BaseModel():
    """Common base for the models in this package.

    Subclasses fill ``model_names`` / ``loss_names`` / ``visual_names`` (and,
    optionally, ``auxiliary_model_names``) and create the matching
    ``net<name>`` / ``loss_<name>`` / visual attributes that the helpers below
    look up with ``getattr``. When training, ``setup`` also expects an
    ``optimizers`` list.
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return *parser* unchanged; subclasses override to add options."""
        return parser

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, opt):
        """Store *opt* and derive devices, checkpoint directories and name lists."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.gpu_ids_p = opt.gpu_ids_p
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        # device_p indexes gpu_ids_p, so that list must be non-empty as well;
        # CPU-only runs (empty gpu_ids) keep the auxiliary device on the CPU too.
        if self.gpu_ids and self.gpu_ids_p:
            self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0]))
        else:
            self.device_p = torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root)
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        """Store the raw input batch; subclasses override to unpack it."""
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        """Create LR schedulers (training), load checkpoints and print the networks."""
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        # Models without auxiliary networks need not define this list at all.
        if len(getattr(self, 'auxiliary_model_names', [])) > 0:
            self.load_auxiliary_networks()
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        """Switch every network listed in ``model_names`` to eval mode."""
        for name in self.model_names:
            if isinstance(name, str):
                getattr(self, 'net' + name).eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every scheduler and print the first optimizer's learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        """Return an OrderedDict mapping each visual name to its attribute value."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to ``float(self.loss_<name>)``."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save each network to ``<save_dir>/<which_epoch>_net_<name>.pth``."""
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # On GPU the net is accessed through ``.module`` (wrapper);
                    # save its CPU weights, then move it back to the GPU.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def save_networks2(self, which_epoch):
        """Save G* nets into one ``_net_gen.pt`` file and D* nets into ``_net_dis.pt``.

        Networks whose name starts with neither 'G' nor 'D' are saved to their
        own ``<which_epoch>_net_<name>.pth`` file.
        """
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
        dict_gen = {}
        dict_dis = {}
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    state_dict = net.module.cpu().state_dict()
                    net.cuda(self.gpu_ids[0])
                else:
                    state_dict = net.cpu().state_dict()
                if name[0] == 'G':
                    dict_gen[name] = state_dict
                elif name[0] == 'D':
                    dict_dis[name] = state_dict
                else:
                    save_filename = '%s_net_%s.pth' % (which_epoch, name)
                    save_path = os.path.join(self.save_dir, save_filename)
                    torch.save(state_dict, save_path)
        if dict_gen:
            torch.save(dict_gen, gen_name)
        if dict_dis:
            torch.save(dict_dis, dis_name)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm entries of *state_dict* that *module* does not expect.

        Walks *keys* (a state-dict key split on '.') down the module tree. At
        the leaf of an InstanceNorm module, ``running_mean``/``running_var`` are
        removed when the module's attribute is None and ``num_batches_tracked``
        is always removed (patch for checkpoints saved before PyTorch 0.4).
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            is_instance_norm = module.__class__.__name__.startswith('InstanceNorm')
            if is_instance_norm and key in ('running_mean', 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if is_instance_norm and key == 'num_batches_tracked':
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def _load_state_dict_into(self, net, state_dict):
        """Load *state_dict* into *net* (unwrapping DataParallel) after patching legacy keys."""
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata
        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # copy the keys: the patch pops entries
            self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
        net.load_state_dict(state_dict)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load all networks for *which_epoch*.

        Uses the combined gen/dis format of ``save_networks2`` when its
        ``_net_gen.pt`` file exists, else per-network ``.pth`` files.
        """
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        if os.path.exists(gen_name):
            self.load_networks2(which_epoch)
            return
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (which_epoch, name)
            load_path = os.path.join(self.save_dir, load_filename)
            net = getattr(self, 'net' + name)
            print('loading the model from %s' % load_path)
            # if you are using PyTorch newer than 0.4 (e.g., built from
            # GitHub source), you can remove str() on self.device
            state_dict = torch.load(load_path, map_location=str(self.device))
            self._load_state_dict_into(net, state_dict)

    def load_networks2(self, which_epoch):
        """Load networks written by ``save_networks2`` for *which_epoch*.

        G* nets come from ``_net_gen.pt`` and D* nets from ``_net_dis.pt``.
        The discriminator file is preloaded when training a GAN model, otherwise
        it is read on first use. Other nets come from their own ``.pth`` file.
        """
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
        gen_state_dict = torch.load(gen_name, map_location=str(self.device))
        dis_state_dict = None
        if self.isTrain and self.opt.model != 'apdrawing_style_nogan':
            dis_state_dict = torch.load(dis_name, map_location=str(self.device))
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, 'net' + name)
            if name[0] == 'G':
                print('loading the model %s from %s' % (name, gen_name))
                state_dict = gen_state_dict[name]
            elif name[0] == 'D':
                if dis_state_dict is None:
                    dis_state_dict = torch.load(dis_name, map_location=str(self.device))
                print('loading the model %s from %s' % (name, dis_name))
                state_dict = dis_state_dict[name]
            else:
                # save_networks2 stores these in their own per-network file.
                load_path = os.path.join(self.save_dir, '%s_net_%s.pth' % (which_epoch, name))
                print('loading the model %s from %s' % (name, load_path))
                state_dict = torch.load(load_path, map_location=str(self.device))
            self._load_state_dict_into(net, state_dict)

    # load auxiliary net models from the disk
    def load_auxiliary_networks(self):
        """Load the 'latest' checkpoint of every auxiliary network from ``auxiliary_dir``."""
        for name in self.auxiliary_model_names:
            if not isinstance(name, str):
                continue
            if 'AE' in name and self.opt.ae_small:
                load_filename = '%s_net_%s_small.pth' % ('latest', name)
            elif 'Regressor' in name:
                load_filename = '%s_net_%s%d.pth' % ('latest', name, self.opt.regarch)
            else:
                load_filename = '%s_net_%s.pth' % ('latest', name)
            load_path = os.path.join(self.auxiliary_dir, load_filename)
            net = getattr(self, 'net' + name)
            print('loading the model from %s' % load_path)
            # These auxiliary nets are mapped onto device_p; all others onto device.
            if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair',
                        'Regressorface']:
                state_dict = torch.load(load_path, map_location=str(self.device_p))
            else:
                state_dict = torch.load(load_path, map_location=str(self.device))
            self._load_state_dict_into(net, state_dict)

    # print network information
    def print_networks(self, verbose):
        """Print each network's parameter count (and the network itself if *verbose*)."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = sum(param.numel() for param in net.parameters())
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of *nets* (a net or a list; None entries skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    # =============================================================================================================
    def inverse_mask(self, mask):
        """Return ``1 - mask`` (ones allocated on ``self.device``)."""
        return torch.ones(mask.shape).to(self.device) - mask

    def masked(self, A, mask):
        """Map *A* from [-1, 1] to [0, 1], multiply by *mask*, map back (masked-out -> -1)."""
        return (A / 2 + 0.5) * mask * 2 - 1

    def add_with_mask(self, A, B, mask):
        """Blend *A* (where mask) with *B* (where 1 - mask), both given in [-1, 1]."""
        return ((A / 2 + 0.5) * mask + (B / 2 + 0.5) * (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1

    def addone_with_mask(self, A, mask):
        """Keep *A* (in [-1, 1]) where mask and set it to 1 (white) where mask is 0."""
        return ((A / 2 + 0.5) * mask + (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1

    # ---- helpers shared by the part combiners / getLocalParts -----------------------------------------------

    def _scaled_part_sizes(self):
        """Return (fineSize, ((w, h) for eyel, eyer, nose, mouth)).

        The opt.*_W / *_H sizes are scaled by fineSize / 256.
        """
        opt = self.opt
        image_size = opt.fineSize
        ratio = image_size / 256
        eye = (opt.EYE_W * ratio, opt.EYE_H * ratio)
        nose = (opt.NOSE_W * ratio, opt.NOSE_H * ratio)
        mouth = (opt.MOUTH_W * ratio, opt.MOUTH_H * ratio)
        return image_size, (eye, eye, nose, mouth)

    @staticmethod
    def _part_bounds(cx, cy, w, h):
        """Return integer bounds (x0, x1, y0, y1) of a w x h box centred at (cx, cy)."""
        return int(cx - w / 2), int(cx + w / 2), int(cy - h / 2), int(cy + h / 2)

    def _mask_parts(self, parts, cmasks, comb_op, region_enm):
        """Return (padvalue, parts) for the pooling mode.

        comb_op == 0 (max pooling) pads with -1 and, for region_enm 1/2,
        multiplies each part by its mask; otherwise (min pooling) pads with 1
        and, for region_enm 1/2, whitens each part outside its mask.
        """
        padvalue = -1 if comb_op == 0 else 1
        if region_enm in [1, 2]:
            if comb_op == 0:
                parts = [part * cmask for part, cmask in zip(parts, cmasks)]
            else:
                parts = [self.addone_with_mask(part, cmask) for part, cmask in zip(parts, cmasks)]
        return padvalue, list(parts)

    def _mask_region(self, x, mask, comb_op):
        """Blacken (comb_op == 0) or whiten (otherwise) the pixels of *x* outside *mask*."""
        return self.masked(x, mask) if comb_op == 0 else self.addone_with_mask(x, mask)

    def _pad_parts_to_full(self, parts, padvalue):
        """Paste each part crop into a fineSize canvas around its ``self.center`` entry.

        Rows of ``self.center[i]`` are (eyel, eyer, nose, mouth) and columns
        are (x, y). The padding uses the same integer bounds that
        ``getLocalParts`` crops with, so a crop pastes back where it came from.
        """
        image_size, part_sizes = self._scaled_part_sizes()
        bs, nc, _, _ = parts[0].shape
        padded = [torch.ones((bs, nc, image_size, image_size)).to(self.device) for _ in parts]
        for i in range(bs):
            center = self.center[i]
            for k, (part, (w, h)) in enumerate(zip(parts, part_sizes)):
                x0, x1, y0, y1 = self._part_bounds(center[k, 0], center[k, 1], w, h)
                pad = (x0, image_size - x1, y0, image_size - y1)  # left, right, top, bottom
                padded[k][i] = torch.nn.ConstantPad2d(pad, padvalue)(part[i])
        return padded

    def _layout_parts(self, parts, padvalue, region_enm):
        """Pad crops to full size (region_enm 0/1) or use full-size parts as is (region_enm 2)."""
        if region_enm in [0, 1]:
            return self._pad_parts_to_full(parts, padvalue)
        if region_enm in [2]:
            return list(parts)
        raise ValueError('unsupported region_enm: %r' % (region_enm,))

    @staticmethod
    def _fold(tensors, comb_op):
        """Reduce *tensors* pairwise with torch.max (comb_op == 0) or torch.min (otherwise)."""
        op = torch.max if comb_op == 0 else torch.min
        result = tensors[0]
        for tensor in tensors[1:]:
            result = op(result, tensor)
        return result

    # ---- part combiners ---------------------------------------------------------------------------------------

    def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op=1, region_enm=0, cmaskel=None,
                     cmasker=None, cmaskno=None, cmaskmo=None):
        '''
        Combine eye/nose/mouth images into one fineSize x fineSize image.

        Mean location of the 5 landmarks (for 256x256):
              x       y
           100.571 123.429
           155.429 123.429
           128.000 155.886
           103.314 185.417
           152.686 185.417
        Pad2d padding order is (left, right, top, bottom).

        comb_op == 0 pads with -1 and merges with max pooling; other values pad
        with 1 and merge with min pooling. For region_enm 0/1 the parts are
        crops padded around self.center[i] (or around fixed positions when
        average_pos); for region_enm 2 they are already full size. region_enm
        1/2 first apply the cmask* masks.
        '''
        padvalue, parts = self._mask_parts((eyel, eyer, nose, mouth), (cmaskel, cmasker, cmaskno, cmaskmo),
                                           comb_op, region_enm)
        if region_enm in [0, 1]:  # need to pad
            image_size, part_sizes = self._scaled_part_sizes()
            if average_pos:
                nose_h = part_sizes[2][1]
                # NOTE(review): these fixed positions are not scaled by fineSize / 256 — confirm for
                # image sizes other than 256.
                avg_center = torch.tensor([[101, 123 - 4], [155, 123 - 4], [128, 156 - nose_h / 2 + 16], [128, 185]])
            bs, nc, _, _ = parts[0].shape
            padded = [torch.ones((bs, nc, image_size, image_size)).to(self.device) for _ in parts]
            for i in range(bs):
                center = avg_center if average_pos else self.center[i]  # x,y
                for k, (part, (w, h)) in enumerate(zip(parts, part_sizes)):
                    cx, cy = center[k, 0], center[k, 1]
                    pad = (int(cx - w / 2 - 1), int(image_size - (cx + w / 2 - 1)),
                           int(cy - h / 2 - 1), int(image_size - (cy + h / 2 - 1)))
                    # Pad with padvalue (previously a hard-coded -1, which made min pooling
                    # return an all -1 image).
                    padded[k][i] = torch.nn.ConstantPad2d(pad, padvalue)(part[i])
        elif region_enm in [2]:
            padded = parts
        else:
            raise ValueError('unsupported region_enm: %r' % (region_enm,))
        return self._fold(padded, comb_op)

    def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op=1, region_enm=0, cmaskel=None, cmasker=None,
                      cmaskno=None, cmaskmo=None):
        """Combine the facial parts with *hair* (restricted to *mask*) into one image.

        Padding, masking and pooling follow ``partCombiner``. Crops are placed
        with integer bounds, since ConstantPad2d requires integer padding.
        """
        hair = self._mask_region(hair, mask, comb_op)
        padvalue, parts = self._mask_parts((eyel, eyer, nose, mouth), (cmaskel, cmasker, cmaskno, cmaskmo),
                                           comb_op, region_enm)
        padded = self._layout_parts(parts, padvalue, region_enm)
        return self._fold(padded + [hair], comb_op)

    def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op=1, region_enm=0, cmaskel=None,
                         cmasker=None, cmaskno=None, cmaskmo=None):
        """Like ``partCombiner2`` but also merges *bg* (restricted to *maskb*)."""
        hair = self._mask_region(hair, maskh, comb_op)
        bg = self._mask_region(bg, maskb, comb_op)
        padvalue, parts = self._mask_parts((eyel, eyer, nose, mouth), (cmaskel, cmasker, cmaskno, cmaskmo),
                                           comb_op, region_enm)
        padded = self._layout_parts(parts, padvalue, region_enm)
        return self._fold(padded + [hair, bg], comb_op)

    def partCombiner3(self, face, hair, maskf, maskh, comb_op=1):
        """Merge *face* (within *maskf*) and *hair* (within *maskh*) by max/min pooling."""
        face = self._mask_region(face, maskf, comb_op)
        hair = self._mask_region(hair, maskh, comb_op)
        return self._fold([face, hair], comb_op)

    @staticmethod
    def tocv2(ts):
        """Convert a (C, H, W) CPU tensor in [-1, 1] to an (H, W, C) uint8 array with channels reversed."""
        img = ((ts.numpy() / 2 + 0.5) * 255).astype('uint8')
        img = img.transpose(1, 2, 0)
        return img[:, :, ::-1]  # rgb->bgr

    @staticmethod
    def totor(img):
        """Inverse of ``tocv2``: reverse channels, convert to a (C, H, W) tensor normalised to [-1, 1].

        Mirrors torchvision's ToTensor (uint8 scaled by 1/255) followed by
        Normalize(mean=0.5, std=0.5), without requiring torchvision.
        """
        # .copy() makes the channel-reversed view contiguous (from_numpy rejects negative strides).
        tor = torch.from_numpy(img[:, :, ::-1].transpose(2, 0, 1).copy())
        if tor.dtype == torch.uint8:
            tor = tor.to(dtype=torch.get_default_dtype()).div(255)
        return (tor - 0.5) / 0.5

    def ContinuityForTest(self, real=0):
        """Append the mean regressor score of the generated patches to ``continuity.txt``.

        When ``real == 1`` the same is done for the real patches in
        ``continuity-r.txt``. Both files live in
        ``<results_dir>/<name>/<phase>_<which_epoch>/``.
        """
        # Patch-based
        self.get_patches()
        self.outputs = self.netRegressor(self.fake_B_patches)
        line_continuity = torch.mean(self.outputs)
        opt = self.opt
        result_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
        message = '%s %.04f' % (self.image_paths[0], line_continuity)
        with open(os.path.join(result_dir, 'continuity.txt'), 'a+') as c_file:
            c_file.write(message)
            c_file.write('\n')
        if real == 1:
            self.get_patches_real()
            self.outputs2 = self.netRegressor(self.real_B_patches)
            line_continuity2 = torch.mean(self.outputs2)
            message = '%s %.04f' % (self.image_paths[0], line_continuity2)
            with open(os.path.join(result_dir, 'continuity-r.txt'), 'a+') as c_file:
                c_file.write(message)
                c_file.write('\n')

    def getLocalParts(self, fakeAB):
        """Extract (eyel, eyer, nose, mouth, hair, bg) from the full image *fakeAB*.

        For region_enm 0/1 the four parts are crops around ``self.center[i]``,
        using the integer bounds the part combiners paste with. For region_enm 2
        they are full-size images masked by the cmask* attributes. hair and bg
        are always full-size masked images.
        """
        bs, nc, _, _ = fakeAB.shape  # dtype torch.float32
        ncr = int(nc / self.opt.output_nc)
        if self.opt.region_enm in [0, 1]:
            _, part_sizes = self._scaled_part_sizes()
            crops = [torch.ones((bs, nc, int(h), int(w))).to(self.device) for w, h in part_sizes]
            for i in range(bs):
                center = self.center[i]
                for k, (w, h) in enumerate(part_sizes):
                    # Integer bounds: the scaled sizes are floats, which cannot index a tensor.
                    x0, x1, y0, y1 = self._part_bounds(center[k, 0], center[k, 1], w, h)
                    crops[k][i] = fakeAB[i, :, y0:y1, x0:x1]
            eyel, eyer, nose, mouth = crops
        elif self.opt.region_enm in [2]:
            eyel = (fakeAB / 2 + 0.5) * self.cmaskel.repeat(1, ncr, 1, 1) * 2 - 1
            eyer = (fakeAB / 2 + 0.5) * self.cmasker.repeat(1, ncr, 1, 1) * 2 - 1
            # self.cmask serves as the nose mask here (there is no cmaskno attribute in this class).
            nose = (fakeAB / 2 + 0.5) * self.cmask.repeat(1, ncr, 1, 1) * 2 - 1
            mouth = (fakeAB / 2 + 0.5) * self.cmaskmo.repeat(1, ncr, 1, 1) * 2 - 1
        else:
            raise ValueError('unsupported region_enm: %r' % (self.opt.region_enm,))
        hair = (fakeAB / 2 + 0.5) * self.mask.repeat(1, ncr, 1, 1) * self.mask2.repeat(1, ncr, 1, 1) * 2 - 1
        bg = (fakeAB / 2 + 0.5) * (torch.ones(fakeAB.shape).to(self.device) - self.mask2.repeat(1, ncr, 1, 1)) * 2 - 1
        return eyel, eyer, nose, mouth, hair, bg

    def getaddw(self, local_name):
        """Return the extra loss weight for a local part/discriminator name (1 if unknown)."""
        if local_name in ['DLEyel', 'DLEyer', 'eyel', 'eyer', 'DLFace', 'face']:
            return self.opt.addw_eye
        if local_name in ['DLNose', 'nose']:
            return self.opt.addw_nose
        if local_name in ['DLMouth', 'mouth']:
            return self.opt.addw_mouth
        if local_name in ['DLHair', 'hair']:
            return self.opt.addw_hair
        if local_name in ['DLBG', 'bg']:
            return self.opt.addw_bg
        return 1