"""Hugging Face ``PreTrainedModel`` wrappers around the project's 3D U-Net variants.

Each wrapper builds its network from the ``in_ch``, ``out_ch`` and
``init_features`` attributes of its associated config class. Each wrapper's
``forward`` delegates directly to the wrapped network.
"""

from transformers import PreTrainedModel

from .unet3d import U_Net, U_Net_DeepSup
from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig


class UNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``U_Net``.

    ``config_class`` ties this model to ``UNet3DConfig``. The wrapped network
    is stored as ``self.model``.
    """

    config_class = UNet3DConfig

    def __init__(self, config: UNet3DConfig):
        """Build the wrapped ``U_Net`` from *config*.

        Args:
            config: Model configuration. Its ``in_ch``, ``out_ch`` and
                ``init_features`` attributes are passed as keyword arguments
                to ``U_Net``.
        """
        super().__init__(config)
        self.model = U_Net(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )

    def forward(self, x):
        """Run the wrapped ``U_Net`` on *x* and return its output unchanged."""
        # NOTE(review): presumably x is a volumetric batch (N, C, D, H, W) with
        # C == config.in_ch -- confirm against U_Net in .unet3d.
        return self.model(x)


class UNetMSS3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``U_Net_DeepSup``.

    ``config_class`` ties this model to ``UNetMSS3DConfig``. The wrapped
    network is stored as ``self.model``.
    """

    config_class = UNetMSS3DConfig

    def __init__(self, config: UNetMSS3DConfig):
        """Build the wrapped ``U_Net_DeepSup`` from *config*.

        Args:
            config: Model configuration. Its ``in_ch``, ``out_ch`` and
                ``init_features`` attributes are passed as keyword arguments
                to ``U_Net_DeepSup``.
        """
        super().__init__(config)
        self.model = U_Net_DeepSup(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )

    def forward(self, x):
        """Run the wrapped ``U_Net_DeepSup`` on *x* and return its output unchanged."""
        # NOTE(review): the name suggests a deep-supervision variant, which may
        # return several outputs rather than a single tensor; the actual
        # return structure is defined in .unet3d and is not visible here -- confirm.
        return self.model(x)