# interfacegan_pp / models / base_generator.py
# Uploaded by ybelkada (commit 4d6b877, "commit files"); file size 8.18 kB.
# python3.7
"""Contains the base class for generator."""
import os
import sys
import logging
import numpy as np
import torch
from . import model_settings
# Public API of this module: only `BaseGenerator` is exported by
# `from ... import *`; `get_temp_logger` stays an internal helper.
__all__ = ['BaseGenerator']
def get_temp_logger(logger_name='logger'):
    """Returns a logger that echoes messages to stdout.

    A `StreamHandler` writing to `sys.stdout` at level DEBUG (with a
    `[time][level] message` format) is attached only when neither this logger
    nor any ancestor it propagates to already has a handler; in that case the
    logger's own level is also set to DEBUG.

    Args:
        logger_name: Name passed to `logging.getLogger()`. Must be non-empty.
            (default: 'logger')

    Returns:
        The `logging.Logger` registered under `logger_name`.

    Raises:
        ValueError: If the input `logger_name` is empty.
    """
    if not logger_name:
        raise ValueError('Input `logger_name` should not be empty!')

    logger = logging.getLogger(logger_name)
    # `hasHandlers()` also inspects ancestor loggers (while propagation is on),
    # so an already-configured root logger means nothing is attached here.
    # This also keeps repeated calls from stacking duplicate handlers.
    if logger.hasHandlers():
        return logger

    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s"))

    logger.setLevel(logging.DEBUG)
    logger.addHandler(stdout_handler)
    return logger
class BaseGenerator(object):
    """Base class for generator used in GAN variants.

    NOTE: The model should be defined with pytorch, and only used for inference.
    Derived classes are expected to implement `build()`, `load()`,
    `convert_tf_model()`, `sample()`, `preprocess()` and `synthesize()`; the
    base versions only raise `NotImplementedError`.
    """

    def __init__(self, model_name, logger=None):
        """Initializes with specific settings.

        The model should be registered in `model_settings.py` with proper
        settings first. Among them, some attributes are necessary, including:

        (1) gan_type: Type of the GAN model.
        (2) latent_space_dim: Dimension of the latent space. Should be a tuple.
        (3) resolution: Resolution of the synthesis.
        (4) min_val: Minimum value of the raw output. (default -1.0)
        (5) max_val: Maximum value of the raw output. (default 1.0)
        (6) channel_order: Channel order of the output image. (default: `RGB`)
        (7) output_channels: Number of output channels. (default 3)

        Args:
            model_name: Name with which the model is registered.
            logger: Logger for recording log messages. If set as `None`, a
                default logger, which prints messages from all levels to
                screen, will be created. (default: None)

        Raises:
            AttributeError: If some necessary attributes are missing.
        """
        self.model_name = model_name
        # Every registered setting becomes an instance attribute.
        for key, val in model_settings.MODEL_POOL[model_name].items():
            setattr(self, key, val)
        self.use_cuda = model_settings.USE_CUDA
        self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE
        self.logger = logger or get_temp_logger(model_name + '_generator')
        self.model = None
        self.run_device = 'cuda' if self.use_cuda else 'cpu'
        self.cpu_device = 'cpu'

        # Check necessary settings.
        self.check_attr('gan_type')
        self.check_attr('latent_space_dim')
        self.check_attr('resolution')
        self.min_val = getattr(self, 'min_val', -1.0)
        self.max_val = getattr(self, 'max_val', 1.0)
        self.output_channels = getattr(self, 'output_channels', 3)
        self.channel_order = getattr(self, 'channel_order', 'RGB').upper()
        assert self.channel_order in ['RGB', 'BGR'], (
            f'Unsupported channel order `{self.channel_order}`!')

        # Build model and load pre-trained weights.
        self.build()
        if os.path.isfile(getattr(self, 'model_path', '')):
            self.load()
        elif os.path.isfile(getattr(self, 'tf_model_path', '')):
            self.convert_tf_model()
        else:
            self.logger.warning('No pre-trained model will be loaded!')

        # Change to inference mode and GPU mode if needed.
        # Compare with `None` explicitly: container modules (e.g. an empty
        # `torch.nn.Sequential`) define `__len__`, so plain truthiness could
        # reject a model that `build()` did set.
        assert self.model is not None, '`build()` should set `self.model`!'
        self.model.eval().to(self.run_device)

    def check_attr(self, attr_name):
        """Checks the existence of a particular attribute.

        Args:
            attr_name: Name of the attribute to check.

        Raises:
            AttributeError: If the target attribute is missing.
        """
        if not hasattr(self, attr_name):
            raise AttributeError(
                f'`{attr_name}` is missing for model `{self.model_name}`!')

    def build(self):
        """Builds the graph (expected to assign `self.model`)."""
        raise NotImplementedError('Should be implemented in derived class!')

    def load(self):
        """Loads pre-trained weights."""
        raise NotImplementedError('Should be implemented in derived class!')

    def convert_tf_model(self, test_num=10):
        """Converts models weights from tensorflow version.

        Args:
            test_num: Number of images to generate for testing whether the
                conversion is done correctly. `0` means skipping the test.
                (default 10)
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def sample(self, num):
        """Samples latent codes randomly.

        Args:
            num: Number of latent codes to sample. Should be positive.

        Returns:
            A `numpy.ndarray` as sampled latent codes.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def preprocess(self, latent_codes):
        """Preprocesses the input latent code if needed.

        Args:
            latent_codes: The input latent codes for preprocessing.

        Returns:
            The preprocessed latent codes which can be used as final input for
            the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def easy_sample(self, num):
        """Wraps functions `sample()` and `preprocess()` together."""
        return self.preprocess(self.sample(num))

    def synthesize(self, latent_codes, **kwargs):
        """Synthesizes images with given latent codes.

        NOTE: The latent codes should have already been preprocessed.

        Args:
            latent_codes: Input latent codes for image synthesis.
            **kwargs: Extra options forwarded by `easy_synthesize()`; their
                meaning is defined by derived classes.

        Returns:
            A dictionary whose values are raw outputs from the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def get_value(self, tensor):
        """Gets value of a `torch.Tensor`.

        Args:
            tensor: The input tensor to get value from.

        Returns:
            A `numpy.ndarray`. A `numpy.ndarray` input is returned as-is.

        Raises:
            ValueError: If the tensor is of neither `torch.Tensor` type nor
                `numpy.ndarray` type.
        """
        if isinstance(tensor, np.ndarray):
            return tensor
        if isinstance(tensor, torch.Tensor):
            # Detach first so the device transfer is not recorded by autograd.
            return tensor.detach().to(self.cpu_device).numpy()
        raise ValueError(f'Unsupported input type `{type(tensor)}`!')

    def postprocess(self, images):
        """Postprocesses the output images if needed.

        This function assumes the input numpy array is with shape [batch_size,
        channel, height, width]. Here, `channel = 3` for color image and
        `channel = 1` for grayscale image. The return images are with shape
        [batch_size, height, width, channel]. NOTE: The channel order of output
        image will always be `RGB`.

        Args:
            images: The raw output from the generator.

        Returns:
            The postprocessed images with dtype `numpy.uint8` with range
            [0, 255].

        Raises:
            ValueError: If the input `images` are not with type `numpy.ndarray`
                or not with shape [batch_size, channel, height, width].
        """
        if not isinstance(images, np.ndarray):
            raise ValueError('Images should be with type `numpy.ndarray`!')
        # NOTE(review): StyleGAN2/3 outputs skip this shape check; the reason
        # is not visible in this file -- confirm against those generators.
        if ('stylegan3' not in self.model_name) and (
                'stylegan2' not in self.model_name):
            images_shape = images.shape
            if len(images_shape) != 4 or images_shape[1] not in [1, 3]:
                raise ValueError(f'Input should be with shape [batch_size, channel, '
                                 f'height, width], where channel equals to 1 or 3. '
                                 f'But {images_shape} is received!')
        # Linearly map [min_val, max_val] onto [0, 255]; the +0.5 rounds to
        # the nearest integer before truncation by `astype`.
        images = (images - self.min_val) * 255 / (self.max_val - self.min_val)
        images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
        images = images.transpose(0, 2, 3, 1)
        if self.channel_order == 'BGR':
            images = images[:, :, :, ::-1]
        return images

    def easy_synthesize(self, latent_codes, **kwargs):
        """Wraps functions `synthesize()` and `postprocess()` together.

        Only the `'image'` entry of the returned dictionary is postprocessed;
        other entries are returned untouched.
        """
        outputs = self.synthesize(latent_codes, **kwargs)
        if 'image' in outputs:
            outputs['image'] = self.postprocess(outputs['image'])
        return outputs

    def get_batch_inputs(self, latent_codes):
        """Gets batch inputs from a collection of latent codes.

        This function will yield at most `self.batch_size` latent_codes at a
        time.

        Args:
            latent_codes: The input latent codes for generation. First
                dimension should be the total number.

        Yields:
            Consecutive slices of `latent_codes` along the first dimension.

        Raises:
            ValueError: If `self.batch_size` is not positive (a negative step
                would otherwise silently yield nothing).
        """
        if self.batch_size <= 0:
            raise ValueError(f'`batch_size` should be positive, '
                             f'but {self.batch_size} is received!')
        total_num = latent_codes.shape[0]
        for i in range(0, total_num, self.batch_size):
            yield latent_codes[i:i + self.batch_size]