# NOTE(review): the lines below are web-page scrape residue from the hosting
# site ("Spaces: Running on CPU Upgrade"), not Python code; commented out so
# the module can be imported.
# Spaces: Running on CPU Upgrade
# Running on CPU Upgrade
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.ops import MLP
class BaseNet:
    """
    Class to construct network.

    Holds factory tables for the supported CNN/ViT backbones and classmethod
    helpers that build and adapt them: gray-scale input alignment, ViT image
    size interpolation, MLP construction, and multi-label classifier heads.
    """
    # Factory callables for the supported CNN backbones, keyed by net_name.
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large
    }

    # Factory callables for the supported ViT backbones, keyed by net_name.
    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14
    }

    # All supported backbones.
    net = {**cnn, **vit}

    # Attribute name of the classifier head per architecture family.
    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads'
    }

    # Per-net_name classifier attribute name, expanded from _classifier.
    classifier = {
        'ResNet18': _classifier['ResNet'],
        'ResNet': _classifier['ResNet'],
        'DenseNet': _classifier['DenseNet'],
        'EfficientNetB0': _classifier['EfficientNet'],
        'EfficientNetB2': _classifier['EfficientNet'],
        'EfficientNetB4': _classifier['EfficientNet'],
        'EfficientNetB6': _classifier['EfficientNet'],
        'EfficientNetV2s': _classifier['EfficientNet'],
        'EfficientNetV2m': _classifier['EfficientNet'],
        'EfficientNetV2l': _classifier['EfficientNet'],
        'ConvNeXtTiny': _classifier['ConvNext'],
        'ConvNeXtSmall': _classifier['ConvNext'],
        'ConvNeXtBase': _classifier['ConvNext'],
        'ConvNeXtLarge': _classifier['ConvNext'],
        'ViTb16': _classifier['ViT'],
        'ViTb32': _classifier['ViT'],
        'ViTl16': _classifier['ViT'],
        'ViTl32': _classifier['ViT'],
        'ViTH14': _classifier['ViT']
    }

    # Hidden layer sizes and dropout shared by every MLP built here.
    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2
    }

    # Shared no-op module used to replace classifier heads on extractors.
    # nn.Identity() is stateless and parameter-free, so sharing one instance
    # across networks is safe.
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): the number of input of MLP
            inplace (bool, optional): parameter for the activation layer, which can
                                      optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
        return mlp

    @classmethod
    def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
        """
        Modify network to handle gray scale image.

        The first convolution's 3-channel weight is collapsed to 1 channel by
        summing over the input-channel dimension.

        Args:
            net_name (str): network name
            net (nn.Module): network itself

        Returns:
            nn.Module: network available for gray scale
        """
        if net_name.startswith('ResNet'):
            net.conv1.in_channels = 1
            net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('DenseNet'):
            net.features.conv0.in_channels = 1
            net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('Efficient'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('ConvNeXt'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('ViT'):
            net.conv_proj.in_channels = 1
            net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))
        else:
            raise ValueError(f"No specified net: {net_name}.")
        return net

    @classmethod
    def set_net(
        cls,
        net_name: str = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Modify network depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used.
                                            Defaults to None.
                                            vit_image_size should be power of patch size.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False.
                                         Defaults to None.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # When ViT: always uses pretrained weights; the `pretrained`
            # argument is intentionally ignored on this branch.
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)
        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
        """
        Modify ViT depending on vit_image_size.

        Loads the pretrained ViT, interpolates its position embeddings to the
        requested image size, and loads the adjusted weights into a fresh ViT
        built for that image size.

        Args:
            net_name (str): ViT name
            vit_image_size (int): image size which ViT handles if ViT is used.

        Returns:
            nn.Module: modified ViT
        """
        base_vit = cls.vit[net_name]
        pretrained_vit = base_vit(weights='DEFAULT')

        # Align weight depending on image size.
        weight = pretrained_vit.state_dict()
        patch_size = int(net_name[-2:])  # 'ViTb16' -> 16
        aligned_weight = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=patch_size,
            model_state=weight
        )
        aligned_vit = base_vit(image_size=vit_image_size)  # Specify new image size.
        aligned_vit.load_state_dict(aligned_weight)        # Load weight which can handle the new image size.
        return aligned_vit

    @classmethod
    def construct_extractor(
        cls,
        net_name: str = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Construct extractor of network depending on net_name.

        Args:
            net_name (str): network name.
            mlp_num_inputs (int, optional): number of input of MLP. Defaults to None.
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used.
                                            Defaults to None.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False.
                                         Defaults to None.

        Returns:
            nn.Module: extractor of network
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
            setattr(extractor, cls.classifier[net_name], cls.DUMMY)  # Replace classifier with DUMMY(=nn.Identity()).
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get classifier of network depending on net_name.

        Builds a fresh (randomly initialized) network just to read the layout
        of its classifier head.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of network
        """
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
        """
        Construct classifier for multi-label.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label
        """
        classifiers = dict()
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)
        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)
        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs)
                )
        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            # NOTE(review): the same layer_norm/flatten instances are reused in
            # every label's head, so the LayerNorm parameters are shared across
            # labels — confirm this sharing is intended.
            for label_name, num_outputs in num_outputs_for_label.items():
                # Shape is changed before nn.Linear.
                classifiers[label_name] = nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs)
                )
        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                    ])
                )
        else:
            raise ValueError(f"No specified net: {net_name}.")
        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return in_feature of network indicating by net_name.
        This class is used in class MultiNetFusion() only.

        Args:
            net_name (str): net_name

        Returns:
            int : in_feature

        Required:
            classifier.in_feature
            classifier.[1].in_features
            classifier.[2].in_features
            classifier.head.in_features
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features
        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features
        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Construct module to align the shape of feature from extractor depending on network.

        Actually, only when net_name == 'ConvNeXt'.
        Because ConvNeXt has the process of aligning the dimensions in its classifier.

        Needs to align shape of the feature extractor when ConvNeXt
        (classifier):
            Sequential(
                (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
                (1): Flatten(start_dim=1, end_dim=-1)
                (2): Linear(in_features=768, out_features=1000, bias=True)
            )

        Args:
            net_name (str): net name

        Returns:
            nn.Module: layers such that they align the dimension of the output
                       from the extractor like the original ConvNeXt.
                       For all other architectures this is nn.Identity().
        """
        aux_module = cls.DUMMY
        if net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            aux_module = nn.Sequential(
                layer_norm,
                flatten
            )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
        """
        Return the last extractor of network.
        This is for Grad-CAM.
        net should be one loaded weight.

        Args:
            net (nn.Module): network itself
            mlp (str): 'MLP', otherwise None
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."
        _extractor = net.extractor_net
        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            # denselayer24 is the last layer of denseblock4 for densenet161,
            # the only DenseNet registered in cls.cnn.
            last_extractor = _extractor.features.denseblock4.denselayer24
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor
class MultiMixin:
    """
    Class to define auxiliary function to handle multi-label.
    """
    def multi_forward(self, out_features: int) -> Dict[str, float]:
        """
        Forward out_features to classifier for each label.

        Args:
            out_features (int): output from extractor

        Returns:
            Dict[str, float]: output of classifier of each label
        """
        # Apply every per-label classifier head to the shared features.
        return {
            label_name: head(out_features)
            for label_name, head in self.multi_classifier.items()
        }
class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Widget base combining nn.Module with the network-construction helpers of
    BaseNet and the multi-label forwarding of MultiMixin in one inheritance.
    """
    pass
class MultiNet(MultiWidget):
    """
    Model of MLP, CNN or ViT.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): MLP, CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # self.extractor_net = MLP or CVmodel
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            mlp_num_inputs=self.mlp_num_inputs,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        # One classifier head per label, fed by the shared extractor.
        self.multi_classifier = self.construct_multi_classifier(net_name=self.net_name, num_outputs_for_label=self.num_outputs_for_label)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: output
        """
        out_features = self.extractor_net(x)
        output = self.multi_forward(out_features)
        return output
class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.

    Tabular data goes through an MLP extractor, the image through a CNN/ViT
    extractor; the concatenated features pass through an intermediate MLP
    before the per-label classifier heads.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP. Defaults to None.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Extractor of MLP and Net
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        # Aligns the extractor output shape (only non-trivial for ConvNeXt).
        self.aux_module = self.construct_aux_module(self.net_name)

        # Intermediate MLP fed by the concatenation of both feature vectors.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(self.net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Multi classifier: heads sized for MLP output since inter_mlp is last.
        self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output
        """
        out_mlp = self.extractor_mlp(x_mlp)
        out_net = self.extractor_net(x_net)
        out_net = self.aux_module(out_net)

        # Fuse tabular and image features, then mix through the inter MLP.
        out_features = torch.cat([out_mlp, out_net], dim=1)
        out_features = self.inter_mlp(out_features)
        output = self.multi_forward(out_features)
        return output
def create_net(
    mlp: Optional[str] = None,
    net: Optional[str] = None,
    num_outputs_for_label: Dict[str, int] = None,
    mlp_num_inputs: int = None,
    in_channel: int = None,
    vit_image_size: int = None,
    pretrained: bool = None
) -> nn.Module:
    """
    Create network.

    Dispatches on which of mlp/net is given:
    MLP only -> MultiNet('MLP'); CNN/ViT only -> MultiNet(net);
    both -> MultiNetFusion; neither -> ValueError.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of input of MLP.
        in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
        vit_image_size (int): image size to be input to ViT.
        pretrained (bool): True when use pretrained CNN or ViT, otherwise False.

    Returns:
        nn.Module: network

    Raises:
        ValueError: when both mlp and net are None.
    """
    _isMLPModel = (mlp is not None) and (net is None)
    _isCVModel = (mlp is None) and (net is not None)
    _isFusion = (mlp is not None) and (net is not None)

    if _isMLPModel:
        multi_net = MultiNet(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=False  # No need of pretrained for MLP
        )
    elif _isCVModel:
        multi_net = MultiNet(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    elif _isFusion:
        multi_net = MultiNetFusion(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    else:
        raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")
    return multi_net