|
import os |
|
import torch |
|
from collections import OrderedDict |
|
from . import networks |
|
|
|
|
|
class BaseModel(): |
|
|
|
|
|
|
|
@staticmethod |
|
def modify_commandline_options(parser, is_train): |
|
return parser |
|
|
|
def name(self): |
|
return 'BaseModel' |
|
|
|
def initialize(self, opt): |
|
self.opt = opt |
|
self.gpu_ids = opt.gpu_ids |
|
self.gpu_ids_p = opt.gpu_ids_p |
|
self.isTrain = opt.isTrain |
|
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') |
|
self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0])) if self.gpu_ids else torch.device('cpu') |
|
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) |
|
self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root) |
|
if opt.resize_or_crop != 'scale_width': |
|
torch.backends.cudnn.benchmark = True |
|
self.loss_names = [] |
|
self.model_names = [] |
|
self.visual_names = [] |
|
self.image_paths = [] |
|
|
|
def set_input(self, input): |
|
self.input = input |
|
|
|
def forward(self): |
|
pass |
|
|
|
|
|
def setup(self, opt, parser=None): |
|
if self.isTrain: |
|
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] |
|
|
|
if not self.isTrain or opt.continue_train: |
|
self.load_networks(opt.which_epoch) |
|
if len(self.auxiliary_model_names) > 0: |
|
self.load_auxiliary_networks() |
|
self.print_networks(opt.verbose) |
|
|
|
|
|
def eval(self): |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, 'net' + name) |
|
net.eval() |
|
|
|
|
|
|
|
def test(self): |
|
with torch.no_grad(): |
|
self.forward() |
|
|
|
|
|
def get_image_paths(self): |
|
return self.image_paths |
|
|
|
def optimize_parameters(self): |
|
pass |
|
|
|
|
|
def update_learning_rate(self): |
|
for scheduler in self.schedulers: |
|
scheduler.step() |
|
lr = self.optimizers[0].param_groups[0]['lr'] |
|
print('learning rate = %.7f' % lr) |
|
|
|
|
|
def get_current_visuals(self): |
|
visual_ret = OrderedDict() |
|
for name in self.visual_names: |
|
if isinstance(name, str): |
|
visual_ret[name] = getattr(self, name) |
|
return visual_ret |
|
|
|
|
|
def get_current_losses(self): |
|
errors_ret = OrderedDict() |
|
for name in self.loss_names: |
|
if isinstance(name, str): |
|
|
|
errors_ret[name] = float(getattr(self, 'loss_' + name)) |
|
return errors_ret |
|
|
|
|
|
def save_networks(self, which_epoch): |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
save_filename = '%s_net_%s.pth' % (which_epoch, name) |
|
save_path = os.path.join(self.save_dir, save_filename) |
|
net = getattr(self, 'net' + name) |
|
|
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): |
|
torch.save(net.module.cpu().state_dict(), save_path) |
|
net.cuda(self.gpu_ids[0]) |
|
else: |
|
torch.save(net.cpu().state_dict(), save_path) |
|
|
|
def save_networks2(self, which_epoch): |
|
gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) |
|
dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch)) |
|
dict_gen = {} |
|
dict_dis = {} |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, 'net' + name) |
|
|
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): |
|
state_dict = net.module.cpu().state_dict() |
|
net.cuda(self.gpu_ids[0]) |
|
else: |
|
state_dict = net.cpu().state_dict() |
|
|
|
if name[0] == 'G': |
|
dict_gen[name] = state_dict |
|
elif name[0] == 'D': |
|
dict_dis[name] = state_dict |
|
else: |
|
save_filename = '%s_net_%s.pth' % (which_epoch, name) |
|
save_path = os.path.join(self.save_dir, save_filename) |
|
torch.save(state_dict, save_path) |
|
if dict_gen: |
|
torch.save(dict_gen, gen_name) |
|
if dict_dis: |
|
torch.save(dict_dis, dis_name) |
|
|
|
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): |
|
key = keys[i] |
|
if i + 1 == len(keys): |
|
if module.__class__.__name__.startswith('InstanceNorm') and \ |
|
(key == 'running_mean' or key == 'running_var'): |
|
if getattr(module, key) is None: |
|
state_dict.pop('.'.join(keys)) |
|
if module.__class__.__name__.startswith('InstanceNorm') and \ |
|
(key == 'num_batches_tracked'): |
|
state_dict.pop('.'.join(keys)) |
|
else: |
|
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) |
|
|
|
|
|
def load_networks(self, which_epoch): |
|
gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) |
|
if os.path.exists(gen_name): |
|
self.load_networks2(which_epoch) |
|
return |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
load_filename = '%s_net_%s.pth' % (which_epoch, name) |
|
load_path = os.path.join(self.save_dir, load_filename) |
|
net = getattr(self, 'net' + name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
print('loading the model from %s' % load_path) |
|
|
|
|
|
state_dict = torch.load(load_path, map_location=str(self.device)) |
|
if hasattr(state_dict, '_metadata'): |
|
del state_dict._metadata |
|
|
|
|
|
for key in list(state_dict.keys()): |
|
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) |
|
net.load_state_dict(state_dict) |
|
|
|
def load_networks2(self, which_epoch): |
|
gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) |
|
gen_state_dict = torch.load(gen_name, map_location=str(self.device)) |
|
if self.isTrain and self.opt.model != 'apdrawing_style_nogan': |
|
dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch)) |
|
dis_state_dict = torch.load(dis_name, map_location=str(self.device)) |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, 'net' + name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
if name[0] == 'G': |
|
print('loading the model %s from %s' % (name, gen_name)) |
|
state_dict = gen_state_dict[name] |
|
elif name[0] == 'D': |
|
print('loading the model %s from %s' % (name, gen_name)) |
|
state_dict = dis_state_dict[name] |
|
|
|
if hasattr(state_dict, '_metadata'): |
|
del state_dict._metadata |
|
|
|
for key in list(state_dict.keys()): |
|
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) |
|
net.load_state_dict(state_dict) |
|
|
|
|
|
def load_auxiliary_networks(self): |
|
for name in self.auxiliary_model_names: |
|
if isinstance(name, str): |
|
if 'AE' in name and self.opt.ae_small: |
|
load_filename = '%s_net_%s_small.pth' % ('latest', name) |
|
elif 'Regressor' in name: |
|
load_filename = '%s_net_%s%d.pth' % ('latest', name, self.opt.regarch) |
|
else: |
|
load_filename = '%s_net_%s.pth' % ('latest', name) |
|
load_path = os.path.join(self.auxiliary_dir, load_filename) |
|
net = getattr(self, 'net' + name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
print('loading the model from %s' % load_path) |
|
|
|
|
|
if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair', |
|
'Regressorface']: |
|
state_dict = torch.load(load_path, map_location=str(self.device_p)) |
|
else: |
|
state_dict = torch.load(load_path, map_location=str(self.device)) |
|
if hasattr(state_dict, '_metadata'): |
|
del state_dict._metadata |
|
|
|
|
|
for key in list(state_dict.keys()): |
|
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) |
|
net.load_state_dict(state_dict) |
|
|
|
|
|
def print_networks(self, verbose): |
|
print('---------- Networks initialized -------------') |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, 'net' + name) |
|
num_params = 0 |
|
for param in net.parameters(): |
|
num_params += param.numel() |
|
if verbose: |
|
print(net) |
|
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) |
|
print('-----------------------------------------------') |
|
|
|
|
|
def set_requires_grad(self, nets, requires_grad=False): |
|
if not isinstance(nets, list): |
|
nets = [nets] |
|
for net in nets: |
|
if net is not None: |
|
for param in net.parameters(): |
|
param.requires_grad = requires_grad |
|
|
|
|
|
def inverse_mask(self, mask): |
|
return torch.ones(mask.shape).to(self.device) - mask |
|
|
|
def masked(self, A, mask): |
|
return (A / 2 + 0.5) * mask * 2 - 1 |
|
|
|
def add_with_mask(self, A, B, mask): |
|
return ((A / 2 + 0.5) * mask + (B / 2 + 0.5) * (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1 |
|
|
|
def addone_with_mask(self, A, mask): |
|
return ((A / 2 + 0.5) * mask + (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1 |
|
|
|
def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op=1, region_enm=0, cmaskel=None, |
|
cmasker=None, cmaskno=None, cmaskmo=None): |
|
''' |
|
x y |
|
100.571 123.429 |
|
155.429 123.429 |
|
128.000 155.886 |
|
103.314 185.417 |
|
152.686 185.417 |
|
this is the mean locaiton of 5 landmarks (for 256x256) |
|
Pad2d Left,Right,Top,Down |
|
''' |
|
if comb_op == 0: |
|
|
|
padvalue = -1 |
|
if region_enm in [1, 2]: |
|
eyel = eyel * cmaskel |
|
eyer = eyer * cmasker |
|
nose = nose * cmaskno |
|
mouth = mouth * cmaskmo |
|
else: |
|
|
|
padvalue = 1 |
|
if region_enm in [1, 2]: |
|
eyel = self.addone_with_mask(eyel, cmaskel) |
|
eyer = self.addone_with_mask(eyer, cmasker) |
|
nose = self.addone_with_mask(nose, cmaskno) |
|
mouth = self.addone_with_mask(mouth, cmaskmo) |
|
if region_enm in [0, 1]: |
|
IMAGE_SIZE = self.opt.fineSize |
|
ratio = IMAGE_SIZE / 256 |
|
EYE_W = self.opt.EYE_W * ratio |
|
EYE_H = self.opt.EYE_H * ratio |
|
NOSE_W = self.opt.NOSE_W * ratio |
|
NOSE_H = self.opt.NOSE_H * ratio |
|
MOUTH_W = self.opt.MOUTH_W * ratio |
|
MOUTH_H = self.opt.MOUTH_H * ratio |
|
bs, nc, _, _ = eyel.shape |
|
eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
for i in range(bs): |
|
if not average_pos: |
|
center = self.center[i] |
|
else: |
|
center = torch.tensor([[101, 123 - 4], [155, 123 - 4], [128, 156 - NOSE_H / 2 + 16], [128, 185]]) |
|
eyel_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2 - 1), |
|
int(IMAGE_SIZE - (center[0, 0] + EYE_W / 2 - 1)), |
|
int(center[0, 1] - EYE_H / 2 - 1), |
|
int(IMAGE_SIZE - (center[0, 1] + EYE_H / 2 - 1))), -1)(eyel[i]) |
|
eyer_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2 - 1), |
|
int(IMAGE_SIZE - (center[1, 0] + EYE_W / 2 - 1)), |
|
int(center[1, 1] - EYE_H / 2 - 1), |
|
int(IMAGE_SIZE - (center[1, 1] + EYE_H / 2 - 1))), -1)(eyer[i]) |
|
nose_p[i] = torch.nn.ConstantPad2d((int(center[2, 0] - NOSE_W / 2 - 1), |
|
int(IMAGE_SIZE - (center[2, 0] + NOSE_W / 2 - 1)), |
|
int(center[2, 1] - NOSE_H / 2 - 1), |
|
int(IMAGE_SIZE - (center[2, 1] + NOSE_H / 2 - 1))), -1)(nose[i]) |
|
mouth_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2 - 1), |
|
int(IMAGE_SIZE - (center[3, 0] + MOUTH_W / 2 - 1)), |
|
int(center[3, 1] - MOUTH_H / 2 - 1), |
|
int(IMAGE_SIZE - (center[3, 1] + MOUTH_H / 2 - 1))), -1)(mouth[i]) |
|
elif region_enm in [2]: |
|
eyel_p = eyel |
|
eyer_p = eyer |
|
nose_p = nose |
|
mouth_p = mouth |
|
if comb_op == 0: |
|
|
|
eyes = torch.max(eyel_p, eyer_p) |
|
eye_nose = torch.max(eyes, nose_p) |
|
result = torch.max(eye_nose, mouth_p) |
|
else: |
|
|
|
eyes = torch.min(eyel_p, eyer_p) |
|
eye_nose = torch.min(eyes, nose_p) |
|
result = torch.min(eye_nose, mouth_p) |
|
return result |
|
|
|
def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op=1, region_enm=0, cmaskel=None, cmasker=None, |
|
cmaskno=None, cmaskmo=None): |
|
if comb_op == 0: |
|
|
|
padvalue = -1 |
|
hair = self.masked(hair, mask) |
|
if region_enm in [1, 2]: |
|
eyel = eyel * cmaskel |
|
eyer = eyer * cmasker |
|
nose = nose * cmaskno |
|
mouth = mouth * cmaskmo |
|
else: |
|
|
|
padvalue = 1 |
|
hair = self.addone_with_mask(hair, mask) |
|
if region_enm in [1, 2]: |
|
eyel = self.addone_with_mask(eyel, cmaskel) |
|
eyer = self.addone_with_mask(eyer, cmasker) |
|
nose = self.addone_with_mask(nose, cmaskno) |
|
mouth = self.addone_with_mask(mouth, cmaskmo) |
|
if region_enm in [0, 1]: |
|
IMAGE_SIZE = self.opt.fineSize |
|
ratio = IMAGE_SIZE / 256 |
|
EYE_W = self.opt.EYE_W * ratio |
|
EYE_H = self.opt.EYE_H * ratio |
|
NOSE_W = self.opt.NOSE_W * ratio |
|
NOSE_H = self.opt.NOSE_H * ratio |
|
MOUTH_W = self.opt.MOUTH_W * ratio |
|
MOUTH_H = self.opt.MOUTH_H * ratio |
|
bs, nc, _, _ = eyel.shape |
|
eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
for i in range(bs): |
|
center = self.center[i] |
|
eyel_p[i] = torch.nn.ConstantPad2d((center[0, 0] - EYE_W / 2, IMAGE_SIZE - (center[0, 0] + EYE_W / 2), |
|
center[0, 1] - EYE_H / 2, IMAGE_SIZE - (center[0, 1] + EYE_H / 2)), |
|
padvalue)(eyel[i]) |
|
eyer_p[i] = torch.nn.ConstantPad2d((center[1, 0] - EYE_W / 2, IMAGE_SIZE - (center[1, 0] + EYE_W / 2), |
|
center[1, 1] - EYE_H / 2, IMAGE_SIZE - (center[1, 1] + EYE_H / 2)), |
|
padvalue)(eyer[i]) |
|
nose_p[i] = torch.nn.ConstantPad2d((center[2, 0] - NOSE_W / 2, IMAGE_SIZE - (center[2, 0] + NOSE_W / 2), |
|
center[2, 1] - NOSE_H / 2, |
|
IMAGE_SIZE - (center[2, 1] + NOSE_H / 2)), padvalue)(nose[i]) |
|
mouth_p[i] = torch.nn.ConstantPad2d((center[3, 0] - MOUTH_W / 2, |
|
IMAGE_SIZE - (center[3, 0] + MOUTH_W / 2), |
|
center[3, 1] - MOUTH_H / 2, |
|
IMAGE_SIZE - (center[3, 1] + MOUTH_H / 2)), padvalue)(mouth[i]) |
|
elif region_enm in [2]: |
|
eyel_p = eyel |
|
eyer_p = eyer |
|
nose_p = nose |
|
mouth_p = mouth |
|
if comb_op == 0: |
|
|
|
eyes = torch.max(eyel_p, eyer_p) |
|
eye_nose = torch.max(eyes, nose_p) |
|
eye_nose_mouth = torch.max(eye_nose, mouth_p) |
|
result = torch.max(hair, eye_nose_mouth) |
|
else: |
|
|
|
eyes = torch.min(eyel_p, eyer_p) |
|
eye_nose = torch.min(eyes, nose_p) |
|
eye_nose_mouth = torch.min(eye_nose, mouth_p) |
|
result = torch.min(hair, eye_nose_mouth) |
|
return result |
|
|
|
def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op=1, region_enm=0, cmaskel=None, |
|
cmasker=None, cmaskno=None, cmaskmo=None): |
|
if comb_op == 0: |
|
|
|
padvalue = -1 |
|
hair = self.masked(hair, maskh) |
|
bg = self.masked(bg, maskb) |
|
if region_enm in [1, 2]: |
|
eyel = eyel * cmaskel |
|
eyer = eyer * cmasker |
|
nose = nose * cmaskno |
|
mouth = mouth * cmaskmo |
|
else: |
|
|
|
padvalue = 1 |
|
hair = self.addone_with_mask(hair, maskh) |
|
bg = self.addone_with_mask(bg, maskb) |
|
if region_enm in [1, 2]: |
|
eyel = self.addone_with_mask(eyel, cmaskel) |
|
eyer = self.addone_with_mask(eyer, cmasker) |
|
nose = self.addone_with_mask(nose, cmaskno) |
|
mouth = self.addone_with_mask(mouth, cmaskmo) |
|
if region_enm in [0, 1]: |
|
IMAGE_SIZE = self.opt.fineSize |
|
ratio = IMAGE_SIZE / 256 |
|
EYE_W = self.opt.EYE_W * ratio |
|
EYE_H = self.opt.EYE_H * ratio |
|
NOSE_W = self.opt.NOSE_W * ratio |
|
NOSE_H = self.opt.NOSE_H * ratio |
|
MOUTH_W = self.opt.MOUTH_W * ratio |
|
MOUTH_H = self.opt.MOUTH_H * ratio |
|
bs, nc, _, _ = eyel.shape |
|
eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device) |
|
for i in range(bs): |
|
center = self.center[i] |
|
eyel_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2), |
|
IMAGE_SIZE - int(center[0, 0] + EYE_W / 2), |
|
int(center[0, 1] - EYE_H / 2), |
|
IMAGE_SIZE - int(center[0, 1] + EYE_H / 2)), padvalue)(eyel[i]) |
|
eyer_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2), |
|
IMAGE_SIZE - int(center[1, 0] + EYE_W / 2), |
|
int(center[1, 1] - EYE_H / 2), |
|
IMAGE_SIZE - int(center[1, 1] + EYE_H / 2)), padvalue)(eyer[i]) |
|
nose_p[i] = torch.nn.ConstantPad2d((int(center[2, 0] - NOSE_W / 2), |
|
IMAGE_SIZE - int(center[2, 0] + NOSE_W / 2), |
|
int(center[2, 1] - NOSE_H / 2), |
|
IMAGE_SIZE - int(center[2, 1] + NOSE_H / 2)), padvalue)(nose[i]) |
|
mouth_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2), |
|
IMAGE_SIZE - int(center[3, 0] + MOUTH_W / 2), |
|
int(center[3, 1] - MOUTH_H / 2), |
|
IMAGE_SIZE - int(center[3, 1] + MOUTH_H / 2)), padvalue)(mouth[i]) |
|
elif region_enm in [2]: |
|
eyel_p = eyel |
|
eyer_p = eyer |
|
nose_p = nose |
|
mouth_p = mouth |
|
if comb_op == 0: |
|
eyes = torch.max(eyel_p, eyer_p) |
|
eye_nose = torch.max(eyes, nose_p) |
|
eye_nose_mouth = torch.max(eye_nose, mouth_p) |
|
eye_nose_mouth_hair = torch.max(hair, eye_nose_mouth) |
|
result = torch.max(bg, eye_nose_mouth_hair) |
|
else: |
|
eyes = torch.min(eyel_p, eyer_p) |
|
eye_nose = torch.min(eyes, nose_p) |
|
eye_nose_mouth = torch.min(eye_nose, mouth_p) |
|
eye_nose_mouth_hair = torch.min(hair, eye_nose_mouth) |
|
result = torch.min(bg, eye_nose_mouth_hair) |
|
return result |
|
|
|
def partCombiner3(self, face, hair, maskf, maskh, comb_op=1): |
|
if comb_op == 0: |
|
|
|
padvalue = -1 |
|
face = self.masked(face, maskf) |
|
hair = self.masked(hair, maskh) |
|
else: |
|
|
|
padvalue = 1 |
|
face = self.addone_with_mask(face, maskf) |
|
hair = self.addone_with_mask(hair, maskh) |
|
if comb_op == 0: |
|
result = torch.max(face, hair) |
|
else: |
|
result = torch.min(face, hair) |
|
return result |
|
|
|
def tocv2(ts): |
|
img = (ts.numpy() / 2 + 0.5) * 255 |
|
img = img.astype('uint8') |
|
img = np.transpose(img, (1, 2, 0)) |
|
img = img[:, :, ::-1] |
|
return img |
|
|
|
def totor(img): |
|
img = img[:, :, ::-1] |
|
tor = transforms.ToTensor()(img) |
|
tor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor) |
|
return tor |
|
|
|
def ContinuityForTest(self, real=0): |
|
|
|
self.get_patches() |
|
self.outputs = self.netRegressor(self.fake_B_patches) |
|
line_continuity = torch.mean(self.outputs) |
|
opt = self.opt |
|
file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt') |
|
message = '%s %.04f' % (self.image_paths[0], line_continuity) |
|
with open(file_name, 'a+') as c_file: |
|
c_file.write(message) |
|
c_file.write('\n') |
|
if real == 1: |
|
self.get_patches_real() |
|
self.outputs2 = self.netRegressor(self.real_B_patches) |
|
line_continuity2 = torch.mean(self.outputs2) |
|
file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), |
|
'continuity-r.txt') |
|
message = '%s %.04f' % (self.image_paths[0], line_continuity2) |
|
with open(file_name, 'a+') as c_file: |
|
c_file.write(message) |
|
c_file.write('\n') |
|
|
|
def getLocalParts(self, fakeAB): |
|
bs, nc, _, _ = fakeAB.shape |
|
ncr = int(nc / self.opt.output_nc) |
|
if self.opt.region_enm in [0, 1]: |
|
ratio = self.opt.fineSize / 256 |
|
EYE_H = self.opt.EYE_H * ratio |
|
EYE_W = self.opt.EYE_W * ratio |
|
NOSE_H = self.opt.NOSE_H * ratio |
|
NOSE_W = self.opt.NOSE_W * ratio |
|
MOUTH_H = self.opt.MOUTH_H * ratio |
|
MOUTH_W = self.opt.MOUTH_W * ratio |
|
eyel = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device) |
|
eyer = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device) |
|
nose = torch.ones((bs, nc, int(NOSE_H), int(NOSE_W))).to(self.device) |
|
mouth = torch.ones((bs, nc, int(MOUTH_H), int(MOUTH_W))).to(self.device) |
|
for i in range(bs): |
|
center = self.center[i] |
|
eyel[i] = fakeAB[i, :, center[0, 1] - EYE_H / 2:center[0, 1] + EYE_H / 2, |
|
center[0, 0] - EYE_W / 2:center[0, 0] + EYE_W / 2] |
|
eyer[i] = fakeAB[i, :, center[1, 1] - EYE_H / 2:center[1, 1] + EYE_H / 2, |
|
center[1, 0] - EYE_W / 2:center[1, 0] + EYE_W / 2] |
|
nose[i] = fakeAB[i, :, center[2, 1] - NOSE_H / 2:center[2, 1] + NOSE_H / 2, |
|
center[2, 0] - NOSE_W / 2:center[2, 0] + NOSE_W / 2] |
|
mouth[i] = fakeAB[i, :, center[3, 1] - MOUTH_H / 2:center[3, 1] + MOUTH_H / 2, |
|
center[3, 0] - MOUTH_W / 2:center[3, 0] + MOUTH_W / 2] |
|
elif self.opt.region_enm in [2]: |
|
eyel = (fakeAB / 2 + 0.5) * self.cmaskel.repeat(1, ncr, 1, 1) * 2 - 1 |
|
eyer = (fakeAB / 2 + 0.5) * self.cmasker.repeat(1, ncr, 1, 1) * 2 - 1 |
|
nose = (fakeAB / 2 + 0.5) * self.cmask.repeat(1, ncr, 1, 1) * 2 - 1 |
|
mouth = (fakeAB / 2 + 0.5) * self.cmaskmo.repeat(1, ncr, 1, 1) * 2 - 1 |
|
hair = (fakeAB / 2 + 0.5) * self.mask.repeat(1, ncr, 1, 1) * self.mask2.repeat(1, ncr, 1, 1) * 2 - 1 |
|
bg = (fakeAB / 2 + 0.5) * (torch.ones(fakeAB.shape).to(self.device) - self.mask2.repeat(1, ncr, 1, 1)) * 2 - 1 |
|
return eyel, eyer, nose, mouth, hair, bg |
|
|
|
def getaddw(self, local_name): |
|
addw = 1 |
|
if local_name in ['DLEyel', 'DLEyer', 'eyel', 'eyer', 'DLFace', 'face']: |
|
addw = self.opt.addw_eye |
|
elif local_name in ['DLNose', 'nose']: |
|
addw = self.opt.addw_nose |
|
elif local_name in ['DLMouth', 'mouth']: |
|
addw = self.opt.addw_mouth |
|
elif local_name in ['DLHair', 'hair']: |
|
addw = self.opt.addw_hair |
|
elif local_name in ['DLBG', 'bg']: |
|
addw = self.opt.addw_bg |
|
return addw |