File size: 8,498 Bytes
0a2ce36
33ad8cf
0a2ce36
 
 
 
 
 
 
 
 
 
 
8ef1fbf
0a2ce36
 
 
 
 
8ef1fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a2ce36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67849dc
0a2ce36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67849dc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#### This is an implementation of deeplabv3 plus for retina detection
from skimage.measure import label, regionprops
import torch
import torchvision
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
import cv2
import torch
from collections import namedtuple

# check you have the right version of timm
# assert timm.__version__ == "0.3.2"
from timm.models.swin_transformer import swin_base_patch4_window12_384_in22k, SwinTransformer

# Seed torch's global RNG so module initialisation is reproducible between runs.
torch.manual_seed(0)

# Preferred compute device as a string: "cuda" when torch sees a GPU, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Default crop margin, in pixels, used by extract_regions_Last.
pad_value = 10

def forward_features(self, x):
    """Run the Swin backbone stem and stages, returning every stage's output.

    Installed as ``SwinTransformer.forward_features`` later in this module so
    that calling the backbone yields multi-scale features.

    Args:
        self: the ``SwinTransformer`` instance being patched.
        x: input image batch.

    Returns:
        list: one entry per element of ``self.layers``, in order, each being
        the raw output of that stage.
    """
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    stage_outputs = []
    for layer in self.layers:
        x = layer(x)
        stage_outputs.append(x)
    # No final ``self.norm``: its result was never part of the returned list,
    # so computing it only wasted a pass over the last stage.
    return stage_outputs

def forward(self, x):
    """Patched ``SwinTransformer.forward``: return ``self.forward_features(x)``.

    No classification head is applied; the caller receives the per-stage
    feature list directly.
    """
    return self.forward_features(x)

# Replace timm's SwinTransformer entry points with the module-level functions
# defined above, so every instance (including the backbone built in
# Transformer_Regression) returns its list of per-stage features.
setattr(SwinTransformer, "forward_features", forward_features)
setattr(SwinTransformer, "forward", forward)




def _shrink_pad(pad, room):
    """Return ``pad`` if ``pad <= room``, else 5 if ``5 <= room``, else 0."""
    if pad <= room:
        return pad
    if 5 <= room:
        return 5
    return 0


def extract_regions_Last(img_test, ytruth, pad1=pad_value, pad2=pad_value, pad3=pad_value, pad4=pad_value):
    """Crop ``img_test`` and ``ytruth`` around the largest labelled region of ``ytruth``.

    Pixels equal to 2 are merged into label 1 before connected-component
    labelling, so both classes form one foreground. The bounding box of the
    largest component (the first one in ``regionprops`` order on ties) is
    expanded by a margin on each side. A margin that would leave the image is
    reduced to 5 pixels, or to 0 if 5 does not fit either.

    Args:
        img_test: image indexed as ``[row, col, channel]`` (sliced with 3 indices).
        ytruth: 2-D label mask aligned with ``img_test``; presumably values
            0/1/2 (only 2 is treated specially) -- TODO confirm with callers.
        pad1, pad2, pad3, pad4: requested margins above, left of, below and
            right of the bounding box, in pixels.

    Returns:
        dict: empty when ``ytruth`` has no foreground. Otherwise it has keys
        ``'image'`` (cropped image), ``'truth'`` (cropped ``ytruth``) and
        ``'cord'`` (``[row_start, row_stop, col_start, col_stop]`` of the crop).
    """
    # Merge label 2 into label 1 so both count as one connected foreground.
    foreground = ytruth.copy()
    foreground[foreground == 2] = 1
    label_img = label(foreground)

    regions = regionprops(label_img)
    cropped_results = dict()
    if not regions:
        return cropped_results

    # max() returns the first region of maximal area, i.e. the same region a
    # strict "area > best" scan would settle on.
    largest = max(regions, key=lambda props: props.area)
    minr, minc, maxr, maxc = largest.bbox
    n_rows, n_cols = label_img.shape[0], label_img.shape[1]

    # Margins are derived from the chosen region only, so they cannot be
    # shrunk by some other region that happens to touch the image border.
    top = _shrink_pad(pad1, minr)
    left = _shrink_pad(pad2, minc)
    bottom = _shrink_pad(pad3, n_rows - maxr)
    right = _shrink_pad(pad4, n_cols - maxc)

    row_start, row_stop = minr - top, maxr + bottom
    col_start, col_stop = minc - left, maxc + right
    cropped_results['image'] = img_test[row_start:row_stop, col_start:col_stop, :]
    cropped_results['truth'] = ytruth[row_start:row_stop, col_start:col_stop]
    cropped_results['cord'] = [row_start, row_stop, col_start, col_stop]
    return cropped_results


class BasicBlock(nn.Module):
    """Residual bottleneck block with identity shortcut.

    ``conv_block1``: 1x1 conv ``channel_num -> 48``, GroupNorm, GELU.
    ``conv_block2``: 3x3 conv ``48 -> channel_num`` (padding 1), GroupNorm, GELU.
    ``forward`` adds the input to the result, so input and output both have
    ``channel_num`` channels and the same spatial size.
    """

    def __init__(self, channel_num):
        super(BasicBlock, self).__init__()
        # GroupNorm needs num_channels % num_groups == 0. Use 8 groups whenever
        # channel_num is a multiple of 8 (the previous fixed setting) and
        # otherwise the largest divisor of 8 that divides channel_num, so e.g.
        # channel_num=1 works instead of raising at construction time.
        out_groups = int(np.gcd(8, channel_num))
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(channel_num, 48, 1, padding=0),
            nn.GroupNorm(num_groups=8, num_channels=48),
            nn.GELU(),
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(48, channel_num, 3, padding=1),
            nn.GroupNorm(num_groups=out_groups, num_channels=channel_num),
            nn.GELU(),
        )
        # Not used by forward(); parameter-free, kept so the attribute still exists.
        self.relu = nn.GELU()

    def forward(self, x):
        """Return ``x + conv_block2(conv_block1(x))``."""
        return x + self.conv_block2(self.conv_block1(x))


class ASPP(nn.Module):
    """Resize head: bilinearly upsamples a feature map to ``image_dim`` x ``image_dim``.

    Despite the name, ``forward`` performs no atrous spatial pyramid pooling.
    It is a single ``F.interpolate`` call with ``align_corners=True``.
    ``Residual2`` and ``pixel_shuffle`` are constructed but never called by
    ``forward``. ``Residual2`` holds conv/GroupNorm parameters, so removing it
    would change this module's state_dict keys.
    """

    def __init__(self, image_dim=384, head=1):
        super().__init__()
        self.image_dim = image_dim
        # Unused in forward(); see class docstring.
        self.Residual2 = BasicBlock(channel_num=head)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.head = head

    def forward(self, x):
        """Return ``x`` (N, C, H, W) resized to (N, C, image_dim, image_dim)."""
        side = self.image_dim
        target = (side, side)
        return F.interpolate(x, size=target, mode='bilinear', align_corners=True)



class Transformer_Regression(nn.Module):
    """Swin-B backbone segmentation network with a main and an auxiliary output.

    The patched backbone returns one output per stage. The first three are
    reshaped to 48x48, 24x24 and 12x12 grids with 256, 512 and 1024 channels.
    PixelShuffle brings the coarser two to 48x48, and the three are
    concatenated into 448 channels. ``conv1`` projects them to 128 channels.
    Each branch then resizes with its own ``ASPP`` module and applies a 3x3
    conv classifier with ``num_classes`` output channels.

    Notes:
        * The hard-coded grid sizes presumably correspond to a 384x384 input
          for the installed timm version -- confirm before changing inputs.
        * ``dim_patch``, ``scale``, ``feat_dim`` and ``self.aux`` are not used
          inside this class.
    """

    def __init__(self, image_dim=224, dim_patch=24, num_classes=3, scale=1, feat_dim=192):
        super().__init__()
        # Keep the submodule creation order stable: with a fixed global seed
        # the random initialisation of each layer depends on it.
        self.backbone = swin_base_patch4_window12_384_in22k(pretrained=True)
        self.aux = 1
        self.dim_patch = dim_patch
        self.image_dim = image_dim
        self.num_classes = num_classes
        # One resize head per output branch.
        self.ASPP1 = ASPP(image_dim, head=128)
        self.ASPP2 = ASPP(image_dim, head=128)
        self.feat_dim = feat_dim
        # 3x3 conv classifiers: 128 fused channels -> num_classes channels.
        self.Classifier_main = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )
        self.Classifier_aux1 = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )
        # NOTE(review): a 1x1 kernel with padding=1 grows the 48x48 fused map
        # to 50x50, adding a 1-pixel border computed only from zero padding.
        # Looks unintended, but changing it would alter trained models' output.
        self.conv1 = nn.Sequential(nn.Conv2d(448, 128, kernel_size=(1, 1), padding=1), nn.GELU())
        self.pixelshufler1 = nn.PixelShuffle(2)
        self.pixelshufler2 = nn.PixelShuffle(4)

    def forward(self, x):
        """Return ``{'seg': main_output, 'seg_aux_1': aux_output}`` for input ``x``."""
        stages = self.backbone(x)
        # (grid side, channels) of the first three stage outputs; each is
        # reshaped to NHWC and then permuted to NCHW.
        grid_specs = ((48, 256), (24, 512), (12, 1024))
        maps = [
            stages[idx].reshape(-1, side, side, chans).permute(0, 3, 1, 2)
            for idx, (side, chans) in enumerate(grid_specs)
        ]
        # Trade channels for resolution so all maps are 48x48:
        # 512 -> 128 channels (x2), 1024 -> 64 channels (x4); 256 + 128 + 64 = 448.
        fused = torch.cat(
            (maps[0], self.pixelshufler1(maps[1]), self.pixelshufler2(maps[2])), 1
        )
        fused = self.conv1(fused)

        main_out = self.Classifier_main(self.ASPP1(fused))
        aux_out = self.Classifier_aux1(self.ASPP2(fused))
        return {'seg': main_out, 'seg_aux_1': aux_out}


Ratios = namedtuple("Ratios", 'cdr hcdr vcdr')
# float32 machine epsilon, added to numerator and denominator of every ratio
# in compute_ratios so an empty disc does not cause a division by zero.
eps = np.finfo(np.float32).eps


def compute_ratios(mask_image):
    '''
    Compute the area, horizontal and vertical cup-to-disc ratios of a label mask.

    Input:
        mask_image: 2-D label mask encoded as 0 = background, 1 = disc, 2 = cup.
                    Every pixel > 0 counts as disc and every pixel > 1 as cup,
                    so the cup is measured as part of the disc. Only this
                    (0, 1, 2) encoding is handled: a (255, 128, 0) mask is NOT
                    detected and would be measured incorrectly, so convert it first.
    Output:
        Ratios(cdr, hcdr, vcdr): a named tuple where
            cdr  = cup pixel count / disc pixel count
            hcdr = max cup pixels in any row / max disc pixels in any row
            vcdr = max cup pixels in any column / max disc pixels in any column
        Each ratio is computed as (cup + eps) / (disc + eps), so a mask with
        no disc and no cup yields 1.0 for all three.
    '''
    # Binary masks as uint8 so the sums below count pixels.
    disc = np.uint8(mask_image > 0)
    cup = np.uint8(mask_image > 1)

    disc_area = np.sum(disc)
    cup_area = np.sum(cup)
    # axis=0 sums down each column (vertical measure), axis=1 along each row
    # (horizontal measure); the largest column/row count is used.
    cup_vert = np.sum(cup, axis=0).max().astype(np.int32)
    cup_horz = np.sum(cup, axis=1).max().astype(np.int32)
    disc_vert = np.sum(disc, axis=0).max().astype(np.int32)
    disc_horz = np.sum(disc, axis=1).max().astype(np.int32)

    cdr = (cup_area + eps) / (disc_area + eps)
    hcdr = (cup_horz + eps) / (disc_horz + eps)
    vcdr = (cup_vert + eps) / (disc_vert + eps)

    return Ratios(cdr, hcdr, vcdr)