Glaucoma-Detection / pipline.py
bigmed@bigmed
monkey patched timm swintransformer model
8ef1fbf
raw
history blame
No virus
8.5 kB
#### This is an implementation of DeepLabV3+ for retina detection
import torch
import torchvision
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
import cv2
from skimage.measure import label, regionprops
import torch
from collections import namedtuple
# check you have the right version of timm
# assert timm.__version__ == "0.3.2"
from timm.models.swin_transformer import swin_base_patch4_window12_384_in22k, SwinTransformer
# Fix torch's global RNG seed so that runs are reproducible.
torch.manual_seed(0)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Default margin (in pixels) that extract_regions_Last adds around the
# region it crops.
pad_value = 10
def forward_features(self, x):
    """Run the Swin backbone and return the output of every stage.

    This function is installed as ``SwinTransformer.forward_features`` so
    that the segmentation head receives one feature tensor per stage instead
    of only the final stage output.

    Args:
        self: the ``SwinTransformer`` instance being patched.
        x: input batch, fed to ``self.patch_embed``.

    Returns:
        list: the output of each entry of ``self.layers``, in order.
    """
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    stage_outputs = []
    for layer in self.layers:
        x = layer(x)
        stage_outputs.append(x)
    # ``self.norm`` is deliberately not applied: its result would be unused,
    # because callers consume the raw per-stage outputs collected above.
    return stage_outputs
def forward(self, x):
    """Forward pass installed on ``SwinTransformer``.

    The classification head is bypassed: whatever
    ``self.forward_features(x)`` returns is passed straight through.
    """
    return self.forward_features(x)
# Monkey-patch timm's SwinTransformer class with the two module-level
# functions defined above, so every instance (including the backbone built in
# Transformer_Regression) returns the list of per-stage outputs.
SwinTransformer.forward_features, SwinTransformer.forward = (
    forward_features,
    forward,
)
def _fit_pad(pad, room):
    """Return ``pad`` if it fits in ``room`` pixels, else 5 if that fits, else 0."""
    if pad > room:
        pad = 5
    if pad > room:
        pad = 0
    return pad


def extract_regions_Last(img_test, ytruth, pad1=pad_value, pad2=pad_value, pad3=pad_value, pad4=pad_value):
    """Crop image and mask around the largest connected region of ``ytruth``.

    Label 2 is relabelled to 1 (on a copy), so pixels of both labels form
    shared connected components. The component with the largest area is
    selected; the first one wins on ties. Its bounding box is enlarged by
    ``pad1`` (top), ``pad2`` (left), ``pad3`` (bottom) and ``pad4`` (right).
    A margin that would extend past the image edge is replaced by 5, and by
    0 if 5 still does not fit.

    Args:
        img_test: image indexed as ``[row, col, channel]``.
        ytruth: label mask indexed as ``[row, col]``, aligned with
            ``img_test``. It is not modified.
        pad1, pad2, pad3, pad4: requested margins in pixels.

    Returns:
        dict: ``'image'`` (cropped image), ``'truth'`` (cropped mask) and
        ``'cord'`` (``[row_start, row_stop, col_start, col_stop]`` of the
        crop). An empty dict is returned when no region is found.
    """
    merged = ytruth.copy()
    merged[merged == 2] = 1
    label_img = label(merged)
    regions = regionprops(label_img)
    if not regions:
        return {}
    # max() returns the first region of maximal area, like a strict '>' scan.
    largest = max(regions, key=lambda props: props.area)
    minr, minc, maxr, maxc = largest.bbox
    n_rows, n_cols = label_img.shape[0], label_img.shape[1]
    # Margins are fitted for the selected region only. Previously they were
    # mutated while scanning, so an earlier, smaller region near the border
    # could shrink the margins used for the final crop.
    top = _fit_pad(pad1, minr)
    left = _fit_pad(pad2, minc)
    bottom = _fit_pad(pad3, n_rows - maxr)
    right = _fit_pad(pad4, n_cols - maxc)
    row_start, row_stop = minr - top, maxr + bottom
    col_start, col_stop = minc - left, maxc + right
    return {
        'image': img_test[row_start:row_stop, col_start:col_stop, :],
        'truth': ytruth[row_start:row_stop, col_start:col_stop],
        'cord': [row_start, row_stop, col_start, col_stop],
    }
class BasicBlock(nn.Module):
    """Residual bottleneck: 1x1 conv down to 48 channels, then 3x3 conv back.

    Each convolution is followed by ``GroupNorm`` (8 groups) and ``GELU``.
    The block input is added to the result, so the output has the same shape
    as the input. ``channel_num`` must be divisible by 8 for the second
    ``GroupNorm``.
    """

    def __init__(self, channel_num):
        super(BasicBlock, self).__init__()
        # Attribute names and construction order are unchanged, so state_dict
        # keys (conv_block1.*, conv_block2.*) stay the same.
        self.conv_block1 = self._conv_norm_act(channel_num, 48, 1)
        self.conv_block2 = self._conv_norm_act(48, channel_num, 3)
        # NOTE(review): never used by forward(); kept for compatibility.
        self.relu = nn.GELU()

    @staticmethod
    def _conv_norm_act(in_channels, out_channels, kernel_size):
        """Build Conv2d -> GroupNorm(8) -> GELU that preserves height/width."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        out = self.conv_block2(self.conv_block1(x))
        return out + x
class ASPP(nn.Module):
    """Output head that bilinearly resizes feature maps to ``image_dim``.

    Despite the name, ``forward`` performs only a bilinear resize
    (``align_corners=True``) to ``(image_dim, image_dim)``. ``Residual2`` and
    ``pixel_shuffle`` are built but never used by ``forward``. They remain so
    that the state_dict keys, which include ``Residual2``'s parameters, are
    unchanged.
    """

    def __init__(self, image_dim=384, head=1):
        super(ASPP, self).__init__()
        self.image_dim = image_dim
        self.head = head
        self.Residual2 = BasicBlock(channel_num=head)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        target_size = (self.image_dim, self.image_dim)
        return F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
class Transformer_Regression(nn.Module):
    """Swin-B backbone with a light multi-scale segmentation head.

    The patched ``SwinTransformer`` backbone returns one tensor per stage.
    The first three stages are processed as follows:

    * Each is reshaped to an NCHW feature map.
    * The two coarser maps are moved onto the finest 48x48 grid with
      ``PixelShuffle``.
    * All maps are concatenated (256 + 512/4 + 1024/16 = 448 channels) and
      fused to 128 channels by ``conv1``.

    Two heads (``ASPP1``/``ASPP2`` resize followed by a 3x3 classifier) then
    produce ``num_classes`` score maps.

    NOTE(review): the grid sizes in ``_STAGE_GRIDS`` presumably correspond to
    a 384x384 input (the ``*_384`` backbone). ``image_dim`` only sets the
    output size; confirm before using other input sizes.
    """

    # (grid side, channels) of backbone stages 0, 1 and 2.
    _STAGE_GRIDS = ((48, 256), (24, 512), (12, 1024))

    def __init__(self, image_dim=224, dim_patch=24, num_classes=3, scale=1, feat_dim=192):
        """
        Args:
            image_dim: side length of the output score maps (ASPP resize target).
            dim_patch: stored on the instance; not used by ``forward``.
            num_classes: number of output channels of both classifiers.
            scale: accepted for compatibility and ignored.
            feat_dim: stored on the instance; not used by ``forward``.
        """
        super(Transformer_Regression, self).__init__()
        self.backbone = swin_base_patch4_window12_384_in22k(pretrained=True)
        self.aux = 1
        self.dim_patch = dim_patch
        self.image_dim = image_dim
        self.num_classes = num_classes
        self.ASPP1 = ASPP(image_dim, head=128)
        self.ASPP2 = ASPP(image_dim, head=128)
        self.feat_dim = feat_dim
        self.Classifier_main = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )
        self.Classifier_aux1 = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )
        # NOTE(review): padding=1 on a 1x1 conv grows the 48x48 map to 50x50,
        # adding a border that carries only GELU(bias). It is left unchanged
        # because any weights trained with this layer depend on it. Confirm
        # before switching to padding=0.
        self.conv1 = nn.Sequential(nn.Conv2d(448, 128, kernel_size=(1, 1), padding=1), nn.GELU())
        self.pixelshufler1 = nn.PixelShuffle(2)
        self.pixelshufler2 = nn.PixelShuffle(4)

    def forward(self, x):
        """Return ``{'seg': main logits, 'seg_aux_1': auxiliary logits}``."""
        stage_outputs = self.backbone(x)
        # Reshape stages 0..2 to (B, side, side, C), then permute to NCHW.
        feats = [
            stage_outputs[idx].reshape(-1, side, side, channels).permute(0, 3, 1, 2)
            for idx, (side, channels) in enumerate(self._STAGE_GRIDS)
        ]
        # Trade channels for resolution so the 24x24 and 12x12 maps land on
        # the 48x48 grid, then fuse all three scales.
        fused = torch.cat(
            (feats[0], self.pixelshufler1(feats[1]), self.pixelshufler2(feats[2])),
            1,
        )
        fused = self.conv1(fused)
        x_main = self.Classifier_main(self.ASPP1(fused))
        x_aux_1 = self.Classifier_aux1(self.ASPP2(fused))
        return {'seg': x_main, 'seg_aux_1': x_aux_1}
Ratios = namedtuple("Ratios", 'cdr hcdr vcdr')
eps = np.finfo(np.float32).eps


def compute_ratios(mask_image):
    '''
    Compute the area, horizontal and vertical cup-to-disc ratios of a mask.

    Input:
        mask_image: 2-D array-like with values (0, 1, 2) or (255, 128, 0)
            for background, disc and cup respectively.
            - The (255, 128, 0) encoding is detected by a maximum value of
              255.
            - Any other input is read as the (0, 1, 2) encoding: values > 0
              are disc (the cup counts as part of the disc) and values > 1
              are cup.
    Output:
        Ratios(cdr, hcdr, vcdr): a named tuple with
            cdr  = cup area / disc area,
            hcdr = max cup pixels in a row / max disc pixels in a row,
            vcdr = max cup pixels in a column / max disc pixels in a column.
        ``eps`` is added to each numerator and denominator, so an empty disc
        and cup give ratios of 1.0 instead of a division by zero.
    '''
    mask_image = np.asarray(mask_image)
    if mask_image.max() == 255:
        # (255, 128, 0) encoding: every non-background pixel is disc; 0 is cup.
        disc = np.uint8(mask_image < 255)
        cup = np.uint8(mask_image == 0)
    else:
        # (0, 1, 2) encoding.
        disc = np.uint8(mask_image > 0)
        cup = np.uint8(mask_image > 1)
    disc_area = np.sum(disc)
    cup_area = np.sum(cup)
    # Column sums (axis=0) count pixels vertically and row sums (axis=1)
    # count them horizontally; the max is the largest count in any column/row.
    cup_vert = np.sum(cup, axis=0).max().astype(np.int32)
    cup_horz = np.sum(cup, axis=1).max().astype(np.int32)
    disc_vert = np.sum(disc, axis=0).max().astype(np.int32)
    disc_horz = np.sum(disc, axis=1).max().astype(np.int32)
    cdr = (cup_area + eps) / (disc_area + eps)
    hcdr = (cup_horz + eps) / (disc_horz + eps)
    vcdr = (cup_vert + eps) / (disc_vert + eps)
    return Ratios(cdr, hcdr, vcdr)