#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Dict, Union

# Alias of typing.
# e.g. {'labels': {'label_A': torch.Tensor([0, 1, ...]), ...}}
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
class RMSELoss(nn.Module):
    """
    Root mean squared error (RMSE) loss.

    Wraps nn.MSELoss and takes the square root; a small epsilon is added
    before the sqrt so the gradient stays finite when MSE == 0.
    """
    def __init__(self, eps: float = 1e-7) -> None:
        """
        Args:
            eps (float, optional): value to avoid 0 (sqrt(0) has an
                infinite gradient). Defaults to 1e-7.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
        """
        Calculate RMSE.

        Args:
            yhat (torch.FloatTensor): prediction values
            y (torch.FloatTensor): ground truth values

        Returns:
            torch.FloatTensor: scalar RMSE (sqrt(MSE + eps))
        """
        _loss = self.mse(yhat, y) + self.eps
        return torch.sqrt(_loss)
class Regularization:
    """
    Weighted p-norm penalty over a network's weight parameters.

    Instances are callable: ``Regularization(order, weight_decay)(net)``
    returns ``weight_decay * sum(||w||_order)`` over every parameter whose
    name contains 'weight' (biases are skipped).
    """
    def __init__(self, order: int, weight_decay: float) -> None:
        """
        The initialization of Regularization class.

        Args:
            order: (int) norm order number (p in the p-norm)
            weight_decay: (float) weight decay rate applied to the summed norms
        """
        super().__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, network: nn.Module) -> torch.FloatTensor:
        """
        Calculates regularization(self.order) loss for network.

        Args:
            network: (torch.nn.Module object) network whose weights are penalized

        Returns:
            torch.FloatTensor: the regularization(self.order) loss
        """
        # Only parameters whose name contains 'weight' are penalized.
        norms = [
            torch.norm(param, p=self.order)
            for name, param in network.named_parameters()
            if 'weight' in name
        ]
        return self.weight_decay * sum(norms)
class NegativeLogLikelihood(nn.Module):
    """
    Negative log partial likelihood for survival analysis (DeepSurv-style
    Cox regression), plus an L2 penalty on the network's weights.
    """
    def __init__(self, device: torch.device) -> None:
        """
        Args:
            device (torch.device): device on which the risk-set mask is built
        """
        super().__init__()
        # Fixed coefficient for the L2 weight penalty added to the NLL term.
        self.L2_reg = 0.05
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)
        self.device = device
    def forward(
        self,
        output: torch.FloatTensor,
        label: torch.IntTensor,
        periods: torch.FloatTensor,
        network: nn.Module
    ) -> torch.FloatTensor:
        """
        Calculates Negative Log Likelihood.

        Args:
            output (torch.FloatTensor): prediction value, ie risk prediction;
                column vector of shape (N, 1) — the caller reshapes to this
            label (torch.IntTensor): occurrence of event (1 = event), shape (N, 1)
            periods (torch.FloatTensor): period (observation time), shape (N, 1)
            network (nn.Network): network whose weights feed the L2 penalty

        Returns:
            torch.FloatTensor: Negative Log Likelihood (+ L2 penalty)
        """
        # Risk-set mask over subject pairs: mask[i, j] is zeroed below when
        # periods[j] > periods[i], so mask[i, j] == 1 iff subject i is still
        # at risk (periods[i] >= periods[j]) at subject j's time.
        mask = torch.ones(periods.shape[0], periods.shape[0]).to(self.device)  # output and mask should be on the same device.
        mask[(periods.T - periods) > 0] = 0
        # exp(risk) restricted to each column's risk set.
        _loss = torch.exp(output) * mask
        # Note: torch.sum(_loss, dim=0) possibly returns nan, in particular MLP.
        # Dividing by the risk-set size makes this the MEAN exp(risk) over
        # each subject's risk set (a plain Cox partial likelihood uses the sum).
        _loss = torch.sum(_loss, dim=0) / torch.sum(mask, dim=0)
        _loss = torch.log(_loss).reshape(-1, 1)
        num_occurs = torch.sum(label)
        if num_occurs.item() == 0.0:
            # NOTE(review): .to() after requires_grad=True returns a non-leaf
            # tensor; harmless here since the returned loss is a constant.
            loss = torch.tensor([1e-7], requires_grad=True).to(self.device)  # To avoid zero division, set small value as loss
            return loss
        else:
            # Sum (log-risk minus log mean risk-set hazard) over events only,
            # averaged over the number of events, then add the L2 penalty.
            neg_log_loss = -torch.sum((output - _loss) * label) / num_occurs
            l2_loss = self.reg(network)
            loss = neg_log_loss + l2_loss
            return loss
class ClsCriterion:
    """
    Criterion for classification: cross-entropy per label plus their sum.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set CrossEntropyLoss.

        Args:
            device (torch.device): device on which the running total lives
        """
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Shapes/dtypes (no reshape and no cast):
            output: (N, num_classes), torch.float32
            label:  (N,), torch.int64
        label.dtype should be torch.int64, otherwise nn.CrossEntropyLoss() causes error.

        Args:
            outputs (Dict[str, torch.FloatTensor], optional): output per label name
            labels (Dict[str, LabelDict]): ground truth under key 'labels'

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss,
            e.g. {'total': loss_total, 'label_A': loss_A, 'label_B': loss_B, ...}
        """
        ground_truths = labels['labels']
        # 'total' is seeded first so it accumulates every per-label loss.
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name, truth in ground_truths.items():
            per_label_loss = self.criterion(outputs[label_name], truth)
            losses[label_name] = per_label_loss
            losses['total'] = torch.add(losses['total'], per_label_loss)
        return losses
class RegCriterion:
    """
    Criterion for regression: MSE, RMSE, or MAE per label plus their sum.
    """
    def __init__(self, criterion_name: str = None, device: torch.device = None) -> None:
        """
        Set MSE, RMSE or MAE.

        Args:
            criterion_name (str): 'MSE', 'RMSE', or 'MAE'
            device (torch.device): device on which the running total lives

        Raises:
            ValueError: if criterion_name is not one of the supported names
        """
        self.device = device
        if criterion_name == 'MSE':
            self.criterion = nn.MSELoss()
        elif criterion_name == 'RMSE':
            self.criterion = RMSELoss()
        elif criterion_name == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Invalid criterion for regression: {criterion_name}.")

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Shapes/dtypes (reshape and cast):
            output: (N, 1) -> (N,), torch.float32
            label:  (N,), torch.float64 -> torch.float32
        label.dtype should be torch.float32, otherwise cannot backward.

        Args:
            outputs (Dict[str, torch.FloatTensor], optional): output per label name
            labels (Dict[str, LabelDict]): ground truth under key 'labels'

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss,
            e.g. {'total': loss_total, 'label_A': loss_A, 'label_B': loss_B, ...}
        """
        # Squeeze only the trailing feature dim: (N, 1) -> (N,). A bare
        # .squeeze() would also drop the batch dim when N == 1, leaving a
        # 0-d output against a 1-d label (silent broadcast).
        _outputs = {label_name: _output.squeeze(-1) for label_name, _output in outputs.items()}
        _labels = {label_name: _label.to(torch.float32) for label_name, _label in labels['labels'].items()}
        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name in labels['labels'].keys():
            _output = _outputs[label_name]
            _label = _labels[label_name]
            _label_loss = self.criterion(_output, _label)
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses
class DeepSurvCriterion:
    """
    Criterion for DeepSurv: negative log likelihood per label plus their sum.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set NegativeLogLikelihood.

        Args:
            device (torch.device, optional): device
        """
        self.device = device
        self.criterion = NegativeLogLikelihood(self.device).to(self.device)

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Shapes/dtypes (reshape and no cast):
            output: (N, 1), torch.float32
            label:  (N,) -> (N, 1), torch.int64
            period: (N,) -> (N, 1), torch.float32

        Args:
            outputs (Dict[str, torch.FloatTensor], optional): output per label name
            labels (Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]):
                dict carrying 'labels', 'periods', and 'network'

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss,
            e.g. {'total': loss_total, 'label_A': loss_A, 'label_B': loss_B, ...}
        """
        # Reshape labels and periods to column vectors for the NLL criterion.
        ground_truths = {name: lbl.reshape(-1, 1) for name, lbl in labels['labels'].items()}
        periods = labels['periods'].reshape(-1, 1)
        network = labels['network']
        # 'total' is seeded first so it accumulates every per-label loss.
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name, truth in ground_truths.items():
            per_label_loss = self.criterion(outputs[label_name], truth, periods, network)
            losses[label_name] = per_label_loss
            losses['total'] = torch.add(losses['total'], per_label_loss)
        return losses
def set_criterion(
    criterion_name: str,
    device: torch.device
) -> Union[ClsCriterion, RegCriterion, DeepSurvCriterion]:
    """
    Return criterion class matching criterion_name.

    Args:
        criterion_name (str): criterion name ('CEL', 'MSE', 'RMSE', 'MAE', or 'NLL')
        device (torch.device): device

    Returns:
        Union[ClsCriterion, RegCriterion, DeepSurvCriterion]: criterion class

    Raises:
        ValueError: if criterion_name is not recognized
    """
    # Guard clauses instead of an if/elif/else chain.
    if criterion_name == 'CEL':
        return ClsCriterion(device=device)
    if criterion_name in ('MSE', 'RMSE', 'MAE'):
        return RegCriterion(criterion_name=criterion_name, device=device)
    if criterion_name == 'NLL':
        return DeepSurvCriterion(device=device)
    raise ValueError(f"Invalid criterion: {criterion_name}.")