#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .component import create_net
from .logger import BaseLogger
from lib import ParamSet


# Alias of typing
# eg. {'labels': {'label_A': torch.Tensor([0, 1, ...]), ...}}
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]


logger = BaseLogger.get_logger(__name__)


def _build_network(params: ParamSet) -> nn.Module:
    """
    Build a fresh network from parameters.

    Shared by BaseModel.__init__ and ModelMixin.init_network so that both
    construct the network from exactly the same arguments.

    Args:
        params (ParamSet): parameters

    Returns:
        nn.Module: network returned by create_net (not yet moved to a device)
    """
    return create_net(
        mlp=params.mlp,
        net=params.net,
        num_outputs_for_label=params.num_outputs_for_label,
        mlp_num_inputs=params.mlp_num_inputs,
        in_channel=params.in_channel,
        vit_image_size=params.vit_image_size,
        pretrained=params.pretrained
    )


class BaseModel(ABC):
    """
    Base class to construct model.

    Holds the parameters, the network built from them, and the temporary
    best weight/epoch recorded by store_weight().
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Build the network from params and move it to params.device.

        Args:
            params (ParamSet): parameters
        """
        self.params = params
        self.device = self.params.device
        self.network = _build_network(self.params)
        self.network.to(self.device)

        # Variables to keep temporary best_weight and best_epoch
        self.acting_best_weight = None
        self.acting_best_epoch = None

    def train(self) -> None:
        """
        Make network training mode.
        """
        self.network.train()

    def eval(self) -> None:
        """
        Make network evaluation mode.
        """
        self.network.eval()

    @abstractmethod
    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        raise NotImplementedError

    def _set_labels(
        self,
        data: Dict
    ) -> Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]:
        """
        Move labels to device for calculating loss.

        When any entry of data['periods'] is truthy (deepsurv), the periods
        moved to device and the network are also included.

        Args:
            data (Dict): dictionary of data with keys 'labels' and 'periods'

        Returns:
            Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]:
                {'labels': {...}}, or
                {'labels': {...}, 'periods': ..., 'network': ...} when deepsurv
        """
        labels = {
            'labels': {
                label_name: label.to(self.device)
                for label_name, label in data['labels'].items()
            }
        }
        if not any(data['periods']):
            return labels

        # When deepsurv
        labels['periods'] = data['periods'].to(self.device)
        labels['network'] = self.network.to(self.device)
        return labels

    def store_weight(self, at_epoch: Optional[int] = None) -> None:
        """
        Store a copy of the current weight and the epoch number.

        The stored state_dict is a deep copy, so further training does not
        modify it.

        Args:
            at_epoch (int, optional): epoch number when weight is stored
        """
        self.acting_best_epoch = at_epoch

        # When DataParallel is used (see ModelMixin.to_gpu), the wrapped
        # network is under `.module`; store the weight of the inner network.
        is_parallel = hasattr(self.network, 'module')
        network = self.network.module if is_parallel else self.network

        # Copy only the state_dict rather than deep-copying the whole network.
        state_dict = copy.deepcopy(network.state_dict())
        if is_parallel:
            # When DataParallel used, move weight to CPU.
            for name, tensor in state_dict.items():
                state_dict[name] = tensor.cpu()
        self.acting_best_weight = state_dict

    def save_weight(
        self,
        save_datetime_dir: Union[str, Path],
        as_best: Optional[bool] = None
    ) -> None:
        """
        Save the stored weight under `<save_datetime_dir>/weights`.

        The file is named 'weight_epoch-XXX.pt', where XXX is the zero-padded
        stored epoch. When as_best is truthy, '_best' is appended to the stem.
        If the non-best file for the same epoch already exists, it is renamed
        instead of being written again.

        Args:
            save_datetime_dir (str or Path): directory in which 'weights' is created
            as_best (bool, optional): True if weight is saved as best, otherwise False.
                Defaults to None.
        """
        save_dir = Path(save_datetime_dir, 'weights')
        save_dir.mkdir(parents=True, exist_ok=True)
        stem = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3)
        save_path = Path(save_dir, stem + '.pt')

        if not as_best:
            torch.save(self.acting_best_weight, save_path)
            return

        save_path_as_best = Path(save_dir, stem + '_best' + '.pt')
        if save_path.exists():
            # Weight of this epoch already saved; rename it with '_best'.
            save_path.rename(save_path_as_best)
        else:
            torch.save(self.acting_best_weight, save_path_as_best)

    def load_weight(self, weight_path: Path) -> None:
        """
        Load weight from weight_path into the network.

        Args:
            weight_path (Path): path to weight
        """
        logger.info(f"Load weight: {weight_path}.\n")
        # Deserialize onto CPU: load_state_dict copies values into the
        # network's existing parameters wherever they live, and this lets a
        # weight saved from GPU tensors be loaded on a CPU-only machine.
        weight = torch.load(weight_path, map_location='cpu')
        self.network.load_state_dict(weight)


class ModelMixin:
    """
    Mixin for GPU placement and network re-initialization.

    Expects the host class to provide `network`, `params` and `device`,
    as BaseModel does.
    """

    def to_gpu(self, gpu_ids: List[int]) -> None:
        """
        Make model compute on the GPU.

        Args:
            gpu_ids (List[int]): GPU ids. An empty list leaves the network unchanged.
        """
        if gpu_ids != []:
            assert torch.cuda.is_available(), 'No available GPU on this machine.'
            self.network = nn.DataParallel(self.network, device_ids=gpu_ids)

    def init_network(self) -> None:
        """
        Initialize network.

        This method is used at test to reset the current weight by redefining network.
        """
        self.network = _build_network(self.params)
        self.network.to(self.device)


class ModelWidget(BaseModel, ModelMixin):
    """
    Class for a widget to inherit multiple classes simultaneously
    """
    pass


class MLPModel(ModelWidget):
    """
    Class for MLP model
    """

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack data for forwarding of MLP and calculating loss by passing them to device.
        When deepsurv, period and network are also returned.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
        """
        in_data = {'inputs': data['inputs'].to(self.device)}
        labels = self._set_labels(data)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        inputs = in_data['inputs']
        output = self.network(inputs)
        return output


class CVModel(ModelWidget):
    """
    Class for CNN or ViT model
    """

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack data for forwarding of CNN or ViT and calculating loss by passing them to device.
        When deepsurv, period and network are also returned.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
        """
        in_data = {'image': data['image'].to(self.device)}
        labels = self._set_labels(data)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        image = in_data['image']
        output = self.network(image)
        return output


class FusionModel(ModelWidget):
    """
    Class for MLP+CNN or MLP+ViT model.
    """

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack data for forwarding of MLP+CNN or MLP+ViT and calculating loss by passing them to device.
        When deepsurv, period and network are also returned.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
        """
        in_data = {
            'inputs': data['inputs'].to(self.device),
            'image': data['image'].to(self.device)
        }
        labels = self._set_labels(data)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        inputs = in_data['inputs']
        image = in_data['image']
        output = self.network(inputs, image)
        return output


def create_model(params: ParamSet) -> BaseModel:
    """
    Construct model depending on which of params.mlp and params.net are set.

    - only mlp: MLPModel
    - only net: CVModel
    - both:     FusionModel

    Args:
        params (ParamSet): parameters

    Returns:
        BaseModel: model

    Raises:
        ValueError: if both params.mlp and params.net are None.
    """
    has_mlp = params.mlp is not None
    has_net = params.net is not None

    if has_mlp and has_net:
        return FusionModel(params)
    if has_mlp:
        return MLPModel(params)
    if has_net:
        return CVModel(params)
    raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")