MedicalAILabo's picture
Upload app.py and lib.
1f53a4c
raw
history blame
No virus
24.6 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-r
from collections import OrderedDict
import torch
import torch.nn as nn
from torchvision.ops import MLP
import torchvision.models as models
from typing import Dict, Optional
class BaseNet:
"""
Class to construct network
"""
cnn = {
'ResNet18': models.resnet18,
'ResNet': models.resnet50,
'DenseNet': models.densenet161,
'EfficientNetB0': models.efficientnet_b0,
'EfficientNetB2': models.efficientnet_b2,
'EfficientNetB4': models.efficientnet_b4,
'EfficientNetB6': models.efficientnet_b6,
'EfficientNetV2s': models.efficientnet_v2_s,
'EfficientNetV2m': models.efficientnet_v2_m,
'EfficientNetV2l': models.efficientnet_v2_l,
'ConvNeXtTiny': models.convnext_tiny,
'ConvNeXtSmall': models.convnext_small,
'ConvNeXtBase': models.convnext_base,
'ConvNeXtLarge': models.convnext_large
}
vit = {
'ViTb16': models.vit_b_16,
'ViTb32': models.vit_b_32,
'ViTl16': models.vit_l_16,
'ViTl32': models.vit_l_32,
'ViTH14': models.vit_h_14
}
net = {**cnn, **vit}
_classifier = {
'ResNet': 'fc',
'DenseNet': 'classifier',
'EfficientNet': 'classifier',
'ConvNext': 'classifier',
'ViT': 'heads'
}
classifier = {
'ResNet18': _classifier['ResNet'],
'ResNet': _classifier['ResNet'],
'DenseNet': _classifier['DenseNet'],
'EfficientNetB0': _classifier['EfficientNet'],
'EfficientNetB2': _classifier['EfficientNet'],
'EfficientNetB4': _classifier['EfficientNet'],
'EfficientNetB6': _classifier['EfficientNet'],
'EfficientNetV2s': _classifier['EfficientNet'],
'EfficientNetV2m': _classifier['EfficientNet'],
'EfficientNetV2l': _classifier['EfficientNet'],
'ConvNeXtTiny': _classifier['ConvNext'],
'ConvNeXtSmall': _classifier['ConvNext'],
'ConvNeXtBase': _classifier['ConvNext'],
'ConvNeXtLarge': _classifier['ConvNext'],
'ViTb16': _classifier['ViT'],
'ViTb32': _classifier['ViT'],
'ViTl16': _classifier['ViT'],
'ViTl32': _classifier['ViT'],
'ViTH14': _classifier['ViT']
}
mlp_config = {
'hidden_channels': [256, 256, 256],
'dropout': 0.2
}
DUMMY = nn.Identity()
@classmethod
def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
"""
Construct MLP.
Args:
mlp_num_inputs (int): the number of input of MLP
inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.
Returns:
MLP: MLP
"""
assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
return mlp
@classmethod
def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
"""
Modify network to handle gray scale image.
Args:
net_name (str): network name
net (nn.Module): network itself
Returns:
nn.Module: network available for gray scale
"""
if net_name.startswith('ResNet'):
net.conv1.in_channels = 1
net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))
elif net_name.startswith('DenseNet'):
net.features.conv0.in_channels = 1
net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))
elif net_name.startswith('Efficient'):
net.features[0][0].in_channels = 1
net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
elif net_name.startswith('ConvNeXt'):
net.features[0][0].in_channels = 1
net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
elif net_name.startswith('ViT'):
net.conv_proj.in_channels = 1
net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))
else:
raise ValueError(f"No specified net: {net_name}.")
return net
@classmethod
def set_net(
cls,
net_name: str = None,
in_channel: int = None,
vit_image_size: int = None,
pretrained: bool = None
) -> nn.Module:
"""
Modify network depending on in_channel and vit_image_size.
Args:
net_name (str): network name
in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
vit_image_size should be power of patch size.
pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.
Returns:
nn.Module: modified network
"""
assert net_name in cls.net, f"No specified net: {net_name}."
if net_name in cls.cnn:
if pretrained:
net = cls.cnn[net_name](weights='DEFAULT')
else:
net = cls.cnn[net_name]()
else:
# When ViT
# always use pretrained
net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)
if in_channel == 1:
net = cls.align_in_channels_1ch(net_name=net_name, net=net)
return net
@classmethod
def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
"""
Modify ViT depending on vit_image_size.
Args:
net_name (str): ViT name
vit_image_size (int): image size which ViT handles if ViT is used.
Returns:
nn.Module: modified ViT
"""
base_vit = cls.vit[net_name]
# pretrained_vit = base_vit(weights=cls.vit_weight[net_name])
pretrained_vit = base_vit(weights='DEFAULT')
# Align weight depending on image size
weight = pretrained_vit.state_dict()
patch_size = int(net_name[-2:]) # 'ViTb16' -> 16
aligned_weight = models.vision_transformer.interpolate_embeddings(
image_size=vit_image_size,
patch_size=patch_size,
model_state=weight
)
aligned_vit = base_vit(image_size=vit_image_size) # Specify new image size.
aligned_vit.load_state_dict(aligned_weight) # Load weight which can handle the new image size.
return aligned_vit
@classmethod
def construct_extractor(
cls,
net_name: str = None,
mlp_num_inputs: int = None,
in_channel: int = None,
vit_image_size: int = None,
pretrained: bool = None
) -> nn.Module:
"""
Construct extractor of network depending on net_name.
Args:
net_name (str): network name.
mlp_num_inputs (int, optional): number of input of MLP. Defaults to None.
in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.
Returns:
nn.Module: extractor of network
"""
if net_name == 'MLP':
extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
else:
extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
setattr(extractor, cls.classifier[net_name], cls.DUMMY) # Replace classifier with DUMMY(=nn.Identity()).
return extractor
@classmethod
def get_classifier(cls, net_name: str) -> nn.Module:
"""
Get classifier of network depending on net_name.
Args:
net_name (str): network name
Returns:
nn.Module: classifier of network
"""
net = cls.net[net_name]()
classifier = getattr(net, cls.classifier[net_name])
return classifier
@classmethod
def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
"""
Construct classifier for multi-label.
Args:
net_name (str): network name
num_outputs_for_label (Dict[str, int]): number of outputs for each label
Returns:
nn.ModuleDict: classifier for multi-label
"""
classifiers = dict()
if net_name == 'MLP':
in_features = cls.mlp_config['hidden_channels'][-1]
for label_name, num_outputs in num_outputs_for_label.items():
classifiers[label_name] = nn.Linear(in_features, num_outputs)
elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier.in_features
for label_name, num_outputs in num_outputs_for_label.items():
classifiers[label_name] = nn.Linear(in_features, num_outputs)
elif net_name.startswith('EfficientNet'):
base_classifier = cls.get_classifier(net_name)
dropout = base_classifier[0].p
in_features = base_classifier[1].in_features
for label_name, num_outputs in num_outputs_for_label.items():
classifiers[label_name] = nn.Sequential(
nn.Dropout(p=dropout, inplace=False),
nn.Linear(in_features, num_outputs)
)
elif net_name.startswith('ConvNeXt'):
base_classifier = cls.get_classifier(net_name)
layer_norm = base_classifier[0]
flatten = base_classifier[1]
in_features = base_classifier[2].in_features
for label_name, num_outputs in num_outputs_for_label.items():
# Shape is changed before nn.Linear.
classifiers[label_name] = nn.Sequential(
layer_norm,
flatten,
nn.Linear(in_features, num_outputs)
)
elif net_name.startswith('ViT'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier.head.in_features
for label_name, num_outputs in num_outputs_for_label.items():
classifiers[label_name] = nn.Sequential(
OrderedDict([
('head', nn.Linear(in_features, num_outputs))
])
)
else:
raise ValueError(f"No specified net: {net_name}.")
multi_classifier = nn.ModuleDict(classifiers)
return multi_classifier
@classmethod
def get_classifier_in_features(cls, net_name: str) -> int:
"""
Return in_feature of network indicating by net_name.
This class is used in class MultiNetFusion() only.
Args:
net_name (str): net_name
Returns:
int : in_feature
Required:
classifier.in_feature
classifier.[1].in_features
classifier.[2].in_features
classifier.head.in_features
"""
if net_name == 'MLP':
in_features = cls.mlp_config['hidden_channels'][-1]
elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier.in_features
elif net_name.startswith('EfficientNet'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier[1].in_features
elif net_name.startswith('ConvNeXt'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier[2].in_features
elif net_name.startswith('ViT'):
base_classifier = cls.get_classifier(net_name)
in_features = base_classifier.head.in_features
else:
raise ValueError(f"No specified net: {net_name}.")
return in_features
@classmethod
def construct_aux_module(cls, net_name: str) -> nn.Sequential:
"""
Construct module to align the shape of feature from extractor depending on network.
Actually, only when net_name == 'ConvNeXt'.
Because ConvNeXt has the process of aligning the dimensions in its classifier.
Needs to align shape of the feature extractor when ConvNeXt
(classifier):
Sequential(
(0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
(1): Flatten(start_dim=1, end_dim=-1)
(2): Linear(in_features=768, out_features=1000, bias=True)
)
Args:
net_name (str): net name
Returns:
nn.Module: layers such that they align the dimension of the output from the extractor like the original ConvNeXt.
"""
aux_module = cls.DUMMY
if net_name.startswith('ConvNeXt'):
base_classifier = cls.get_classifier(net_name)
layer_norm = base_classifier[0]
flatten = base_classifier[1]
aux_module = nn.Sequential(
layer_norm,
flatten
)
return aux_module
@classmethod
def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
"""
Return the last extractor of network.
This is for Grad-CAM.
net should be one loaded weight.
Args:
net (nn.Module): network itself
mlp (str): 'MLP', otherwise None
net_name (str): network name
Returns:
nn.Module: last extractor of network
"""
assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."
_extractor = net.extractor_net
if net_name.startswith('ResNet'):
last_extractor = _extractor.layer4[-1]
elif net_name.startswith('DenseNet'):
last_extractor = _extractor.features.denseblock4.denselayer24
elif net_name.startswith('EfficientNet'):
last_extractor = _extractor.features[-1]
elif net_name.startswith('ConvNeXt'):
last_extractor = _extractor.features[-1][-1].block
elif net_name.startswith('ViT'):
last_extractor = _extractor.encoder.layers[-1]
else:
raise ValueError(f"Cannot get last extractor of net: {net_name}.")
return last_extractor
class MultiMixin:
"""
Class to define auxiliary function to handle multi-label.
"""
def multi_forward(self, out_features: int) -> Dict[str, float]:
"""
Forward out_features to classifier for each label.
Args:
out_features (int): output from extractor
Returns:
Dict[str, float]: output of classifier of each label
"""
output = dict()
for label_name, classifier in self.multi_classifier.items():
output[label_name] = classifier(out_features)
return output
class MultiWidget(nn.Module, BaseNet, MultiMixin):
"""
Class for a widget to inherit multiple classes simultaneously.
"""
pass
class MultiNet(MultiWidget):
"""
Model of MLP, CNN or ViT.
"""
def __init__(
self,
net_name: str = None,
num_outputs_for_label: Dict[str, int] = None,
mlp_num_inputs: int = None,
in_channel: int = None,
vit_image_size: int = None,
pretrained: bool = None
) -> None:
"""
Args:
net_name (str): MLP, CNN or ViT name
num_outputs_for_label (Dict[str, int]): number of classes for each label
mlp_num_inputs (int): number of input of MLP.
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
vit_image_size (int): image size to be input to ViT.
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
"""
super().__init__()
self.net_name = net_name
self.num_outputs_for_label = num_outputs_for_label
self.mlp_num_inputs = mlp_num_inputs
self.in_channel = in_channel
self.vit_image_size = vit_image_size
self.pretrained = pretrained
# self.extractor_net = MLP or CVmodel
self.extractor_net = self.construct_extractor(
net_name=self.net_name,
mlp_num_inputs=self.mlp_num_inputs,
in_channel=self.in_channel,
vit_image_size=self.vit_image_size,
pretrained=self.pretrained
)
self.multi_classifier = self.construct_multi_classifier(net_name=self.net_name, num_outputs_for_label=self.num_outputs_for_label)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Forward.
Args:
x (torch.Tensor): tabular data or image
Returns:
Dict[str, torch.Tensor]: output
"""
out_features = self.extractor_net(x)
output = self.multi_forward(out_features)
return output
class MultiNetFusion(MultiWidget):
"""
Fusion model of MLP and CNN or ViT.
"""
def __init__(
self,
net_name: str = None,
num_outputs_for_label: Dict[str, int] = None,
mlp_num_inputs: int = None,
in_channel: int = None,
vit_image_size: int = None,
pretrained: bool = None
) -> None:
"""
Args:
net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
num_outputs_for_label (Dict[str, int]): number of classes for each label
mlp_num_inputs (int): number of input of MLP. Defaults to None.
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
vit_image_size (int): image size to be input to ViT.
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
"""
assert (net_name != 'MLP'), 'net_name should not be MLP.'
super().__init__()
self.net_name = net_name
self.num_outputs_for_label = num_outputs_for_label
self.mlp_num_inputs = mlp_num_inputs
self.in_channel = in_channel
self.vit_image_size = vit_image_size
self.pretrained = pretrained
# Extractor of MLP and Net
self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
self.extractor_net = self.construct_extractor(
net_name=self.net_name,
in_channel=self.in_channel,
vit_image_size=self.vit_image_size,
pretrained=self.pretrained
)
self.aux_module = self.construct_aux_module(self.net_name)
# Intermediate MLP
self.in_features_from_mlp = self.get_classifier_in_features('MLP')
self.in_features_from_net = self.get_classifier_in_features(self.net_name)
self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)
# Multi classifier
self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)
def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Forward.
Args:
x_mlp (torch.Tensor): tabular data
x_net (torch.Tensor): image
Returns:
Dict[str, torch.Tensor]: output
"""
out_mlp = self.extractor_mlp(x_mlp)
out_net = self.extractor_net(x_net)
out_net = self.aux_module(out_net)
out_features = torch.cat([out_mlp, out_net], dim=1)
out_features = self.inter_mlp(out_features)
output = self.multi_forward(out_features)
return output
def create_net(
mlp: Optional[str] = None,
net: Optional[str] = None,
num_outputs_for_label: Dict[str, int] = None,
mlp_num_inputs: int = None,
in_channel: int = None,
vit_image_size: int = None,
pretrained: bool = None
) -> nn.Module:
"""
Create network.
Args:
mlp (Optional[str]): 'MLP' or None
net (Optional[str]): CNN, ViT name or None
num_outputs_for_label (Dict[str, int]): number of outputs for each label
mlp_num_inputs (int): number of input of MLP.
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
vit_image_size (int): image size to be input to ViT.
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
Returns:
nn.Module: network
"""
_isMLPModel = (mlp is not None) and (net is None)
_isCVModel = (mlp is None) and (net is not None)
_isFusion = (mlp is not None) and (net is not None)
if _isMLPModel:
multi_net = MultiNet(
net_name='MLP',
num_outputs_for_label=num_outputs_for_label,
mlp_num_inputs=mlp_num_inputs,
in_channel=in_channel,
vit_image_size=vit_image_size,
pretrained=False # No need of pretrained for MLP
)
elif _isCVModel:
multi_net = MultiNet(
net_name=net,
num_outputs_for_label=num_outputs_for_label,
mlp_num_inputs=mlp_num_inputs,
in_channel=in_channel,
vit_image_size=vit_image_size,
pretrained=pretrained
)
elif _isFusion:
multi_net = MultiNetFusion(
net_name=net,
num_outputs_for_label=num_outputs_for_label,
mlp_num_inputs=mlp_num_inputs,
in_channel=in_channel,
vit_image_size=vit_image_size,
pretrained=pretrained
)
else:
raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")
return multi_net