#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import pickle
from .logger import BaseLogger
from typing import List, Dict, Union
import pandas as pd

logger = BaseLogger.get_logger(__name__)


class PrivateAugment(torch.nn.Module):
    """
    Augmentations defined privately.
    Further augmentations can be added to this class if necessary.
    """
    # For X-ray photos.
    xray_augs_list = [
        transforms.RandomAffine(degrees=(-3, 3), translate=(0.02, 0.02)),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.RandomAutocontrast()
        ]


class InputDataMixin:
    """
    Class to normalize input data.
    """
    def _make_scaler(self) -> MinMaxScaler:
        """
        Make a scaler that min-max normalizes input data, fitted on the training data.

        Returns:
            MinMaxScaler: scaler
        """
        scaler = MinMaxScaler()
        _df_train = self.df_source[self.df_source['split'] == 'train']  # should be normalized with min and max of training data
        _ = scaler.fit(_df_train[self.input_list])  # fit only
        return scaler

    def save_scaler(self, save_path: str) -> None:
        """
        Save scaler.

        Args:
            save_path (str): path for saving scaler
        """
        with open(save_path, 'wb') as f:
            pickle.dump(self.scaler, f)

    def load_scaler(self, scaler_path: str) -> MinMaxScaler:
        """
        Load scaler.

        Args:
            scaler_path (str): path to scaler

        Returns:
            MinMaxScaler: scaler fitted at training time
        """
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
        return scaler

    def _normalize_inputs(self, df_inputs: pd.DataFrame) -> torch.FloatTensor:
        """
        Normalize inputs.

        Args:
            df_inputs (pd.DataFrame): DataFrame of inputs

        Returns:
            torch.FloatTensor: normalized inputs

        Note:
            After iloc[[idx], index_input_list], a pd.DataFrame is obtained,
            which matches the input type expected by self.scaler.transform.
            However, after normalizing, the shape of inputs_value is (1, N), where N is the number of input values.
            Since the shape (1, N) is not acceptable when forwarding, conversion from (1, N) to (N,) is needed.
        """
        inputs_value = self.scaler.transform(df_inputs).reshape(-1)  # np.float64
        inputs_value = np.array(inputs_value, dtype=np.float32)      # -> np.float32
        inputs_value = torch.from_numpy(inputs_value).clone()        # -> torch.float32
        return inputs_value
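
    # Illustrative sketch of the (1, N) -> (N,) conversion described in the Note above,
    # assuming two hypothetical input columns; the exact values depend on the fitted scaler.
    #   self.scaler.transform(df_inputs)   # e.g. array([[0.25, 0.80]])  shape (1, 2)
    #   .reshape(-1)                       # e.g. array([0.25, 0.80])    shape (2,)
    #   torch.from_numpy(...)              # tensor([0.2500, 0.8000])    dtype torch.float32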

    def _load_input_value_if_mlp(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Load input values and convert them into a tensor if an MLP is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.FloatTensor, str]: tensor of input values, or empty string
        """
        inputs_value = ''
        if self.params.mlp is None:
            return inputs_value

        index_input_list = [self.col_index_dict[input_name] for input_name in self.input_list]
        _df_inputs = self.df_split.iloc[[idx], index_input_list]
        inputs_value = self._normalize_inputs(_df_inputs)
        return inputs_value


class ImageMixin:
    """
    Class to normalize and transform images.
    """
    def _make_augmentations(self) -> transforms.Compose:
        """
        Define which augmentations are applied.
        Augmentation is applied to the train split during training only;
        at test time, no augmentation is applied.
        """
        _augmentation = []
        if (self.params.isTrain) and (self.split == 'train'):
            if self.params.augmentation == 'xrayaug':
                _augmentation = PrivateAugment.xray_augs_list
            elif self.params.augmentation == 'trivialaugwide':
                _augmentation.append(transforms.TrivialAugmentWide())
            elif self.params.augmentation == 'randaug':
                _augmentation.append(transforms.RandAugment())
            else:
                # i.e. self.params.augmentation == 'no'
                pass

        _augmentation = transforms.Compose(_augmentation)
        return _augmentation

    def _make_transforms(self) -> transforms.Compose:
        """
        Make the composed image transforms.

        Returns:
            transforms.Compose: tensor conversion and, optionally, image normalization
        """
        _transforms = []
        _transforms.append(transforms.ToTensor())

        if self.params.normalize_image == 'yes':
            # transforms.Normalize accepts only Tensor.
            if self.params.in_channel == 1:
                _transforms.append(transforms.Normalize(mean=(0.5, ), std=(0.5, )))
            else:
                # i.e. self.params.in_channel == 3
                _transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

        _transforms = transforms.Compose(_transforms)
        return _transforms

    def _open_image_in_channel(self, imgpath: str, in_channel: int) -> Image.Image:
        """
        Open an image with the specified number of channels.

        Args:
            imgpath (str): path to image
            in_channel (int): number of channels, 1 or 3

        Returns:
            Image.Image: PIL image
        """
        if in_channel == 1:
            image = Image.open(imgpath).convert('L')    # eg. np.array(image).shape = (64, 64)
            return image
        else:
            # i.e. in_channel == 3
            image = Image.open(imgpath).convert('RGB')  # eg. np.array(image).shape = (64, 64, 3)
            return image

    def _load_image_if_cnn(self, idx: int) -> Union[torch.Tensor, str]:
        """
        Load an image and convert it to a tensor if a CNN or ViT is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.Tensor, str]: tensor converted from image, or empty string
        """
        image = ''
        if self.params.net is None:
            return image

        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        image = self._open_image_in_channel(imgpath, self.params.in_channel)
        image = self.augmentation(image)
        image = self.transform(image)
        return image
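
    # Shape sketch (illustrative): for a 64x64 source image, the returned tensor is
    # (1, 64, 64) when in_channel == 1 and (3, 64, 64) when in_channel == 3,
    # since ToTensor converts HWC PIL images to CHW float tensors in [0, 1].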


class DeepSurvMixin:
    """
    Class to handle the data required for deepsurv.
    """
    def _load_periods_if_deepsurv(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Return the period if the task is deepsurv.

        Args:
            idx (int): index

        Returns:
            Union[torch.FloatTensor, str]: period, or empty string
        """
        periods = ''
        if self.params.task != 'deepsurv':
            return periods

        assert (self.params.task == 'deepsurv') and (len(self.label_list) == 1), 'Deepsurv cannot work with multi-label.'
        periods = self.df_split.iat[idx, self.col_index_dict[self.period_name]]  # int64
        periods = np.array(periods, dtype=np.float32)   # -> np.float32
        periods = torch.from_numpy(periods).clone()     # -> torch.float32
        return periods


class DataSetWidget(InputDataMixin, ImageMixin, DeepSurvMixin):
    """
    Widget class that combines the mixin classes above.
    """
    pass


class LoadDataSet(Dataset, DataSetWidget):
    """
    Dataset for a split.
    """
    def __init__(
            self,
            params,
            split: str
            ) -> None:
        """
        Args:
            params (ParamSet): parameters for the model
            split (str): split
        """
        self.params = params
        self.df_source = self.params.df_source
        self.split = split

        self.input_list = self.params.input_list
        self.label_list = self.params.label_list
        if self.params.task == 'deepsurv':
            self.period_name = self.params.period_name

        self.df_split = self.df_source[self.df_source['split'] == self.split]
        self.col_index_dict = {col_name: self.df_split.columns.get_loc(col_name) for col_name in self.df_split.columns}

        # For input data
        if self.params.mlp is not None:
            assert (self.input_list != []), 'Input list is empty.'
            if params.isTrain:
                self.scaler = self._make_scaler()
            else:
                # Load the scaler used at training.
                self.scaler = self.load_scaler(self.params.scaler_path)

        # For image
        if self.params.net is not None:
            self.augmentation = self._make_augmentations()
            self.transform = self._make_transforms()

    def __len__(self) -> int:
        """
        Return the length of the DataFrame.

        Returns:
            int: length of DataFrame
        """
        return len(self.df_split)

    def _load_label(self, idx: int) -> Dict[str, Union[int, float]]:
        """
        Return labels.
        If the CSV of an external dataset has no label columns,
        an empty dictionary is returned.

        Args:
            idx (int): index

        Returns:
            Dict[str, Union[int, float]]: dictionary of label names and their values
        """
        # Check whether label columns exist, e.g. when a CSV for an external dataset is used.
        label_list_in_split = list(self.df_split.columns[self.df_split.columns.str.startswith('label')])
        label_dict = dict()
        if label_list_in_split != []:
            for label_name in self.label_list:
                label_dict[label_name] = self.df_split.iat[idx, self.col_index_dict[label_name]]
        else:
            # no label
            pass
        return label_dict

    def __getitem__(self, idx: int) -> Dict:
        """
        Return the data row specified by index.

        Args:
            idx (int): index

        Returns:
            Dict: dictionary of data to be passed to the model
        """
        uniqID = self.df_split.iat[idx, self.col_index_dict['uniqID']]
        group = self.df_split.iat[idx, self.col_index_dict['group']]
        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        split = self.df_split.iat[idx, self.col_index_dict['split']]
        inputs_value = self._load_input_value_if_mlp(idx)
        image = self._load_image_if_cnn(idx)
        label_dict = self._load_label(idx)
        periods = self._load_periods_if_deepsurv(idx)
        _data = {
                'uniqID': uniqID,
                'group': group,
                'imgpath': imgpath,
                'split': split,
                'inputs': inputs_value,
                'image': image,
                'labels': label_dict,
                'periods': periods
                }
        return _data


def _make_sampler(split_data: LoadDataSet) -> WeightedRandomSampler:
    """
    Make a sampler that draws samples inversely to class frequency.

    Args:
        split_data (LoadDataSet): dataset

    Returns:
        WeightedRandomSampler: sampler
    """
    _target = []
    for data in split_data:
        _target.append(list(data['labels'].values())[0])

    _target = np.array(_target)  # convert to array so the element-wise comparison below works
    class_sample_count = np.array([len(np.where(_target == t)[0]) for t in np.unique(_target)])
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in _target])  # assumes integer class labels 0, 1, ..., K-1
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler
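
# Worked example (illustrative): with labels _target = [0, 0, 0, 1],
# class_sample_count = [3, 1], weight = [1/3, 1], and
# samples_weight = [1/3, 1/3, 1/3, 1], so the minority class is drawn more often.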


def create_dataloader(
        params,
        split: str = None
        ) -> DataLoader:
    """
    Create a data loader for a split.

    Args:
        params (ParamSet): parameters for the dataloader
        split (str): split. Defaults to None.

    Returns:
        DataLoader: data loader
    """
    split_data = LoadDataSet(params, split)

    if params.isTrain:
        batch_size = params.batch_size
        shuffle = True
    else:
        batch_size = params.test_batch_size
        shuffle = False

    if params.sampler == 'yes':
        assert ((params.task == 'classification') or (params.task == 'deepsurv')), 'Cannot make sampler in regression.'
        assert (len(params.label_list) == 1), 'Cannot make sampler for multi-label.'
        shuffle = False
        sampler = _make_sampler(split_data)
    else:
        # i.e. params.sampler == 'no'
        sampler = None

    split_loader = DataLoader(
                        dataset=split_data,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        num_workers=0,
                        sampler=sampler
                        )
    return split_loader
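

# Minimal usage sketch (illustrative only): the attribute names accessed on `params`
# in this module (df_source, input_list, label_list, mlp, net, task, isTrain,
# augmentation, normalize_image, in_channel, sampler, batch_size, test_batch_size)
# are expected to be provided by the project's ParamSet, which is defined elsewhere.
#
#   params = ParamSet(...)                                 # hypothetical construction
#   train_loader = create_dataloader(params, split='train')
#   val_loader = create_dataloader(params, split='val')
#   for batch in train_loader:
#       images, labels = batch['image'], batch['labels']
#       ...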