soumickmj commited on
Commit
ec30cc2
1 Parent(s): 3a0b935

Upload UNet3D

Browse files
Files changed (5) hide show
  1. UNetConfigs.py +28 -0
  2. UNets.py +25 -0
  3. config.json +15 -0
  4. model.safetensors +3 -0
  5. unet3d.py +295 -0
UNetConfigs.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class UNet3DConfig(PretrainedConfig):
5
+ model_type = "UNet"
6
+ def __init__(
7
+ self,
8
+ in_ch=1,
9
+ out_ch=1,
10
+ init_features=64,
11
+ **kwargs):
12
+ self.in_ch = in_ch
13
+ self.out_ch = out_ch
14
+ self.init_features = init_features
15
+ super().__init__(**kwargs)
16
+
17
+ class UNetMSS3DConfig(PretrainedConfig):
18
+ model_type = "UNetMSS"
19
+ def __init__(
20
+ self,
21
+ in_ch=1,
22
+ out_ch=1,
23
+ init_features=64,
24
+ **kwargs):
25
+ self.in_ch = in_ch
26
+ self.out_ch = out_ch
27
+ self.init_features = init_features
28
+ super().__init__(**kwargs)
UNets.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .unet3d import U_Net, U_Net_DeepSup
3
+ from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig
4
+
5
+ class UNet3D(PreTrainedModel):
6
+ config_class = UNet3DConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = U_Net(
10
+ in_ch=config.in_ch,
11
+ out_ch=config.out_ch,
12
+ init_features=config.init_features)
13
+ def forward(self, x):
14
+ return self.model(x)
15
+
16
+ class UNetMSS3D(PreTrainedModel):
17
+ config_class = UNetMSS3DConfig
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ self.model = U_Net_DeepSup(
21
+ in_ch=config.in_ch,
22
+ out_ch=config.out_ch,
23
+ init_features=config.init_features)
24
+ def forward(self, x):
25
+ return self.model(x)
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UNet3D"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "UNetConfigs.UNet3DConfig",
7
+ "AutoModel": "UNets.UNet3D"
8
+ },
9
+ "in_ch": 1,
10
+ "init_features": 64,
11
+ "model_type": "UNet",
12
+ "out_ch": 1,
13
+ "torch_dtype": "float32",
14
+ "transformers_version": "4.44.2"
15
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dab15f665940de7226f311709b8452ef75639c2096a0b86ef0bb5fd3364d6ec2
3
+ size 414215700
unet3d.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # from __future__ import print_function, division
4
+ '''
5
+
6
+ Purpose :
7
+
8
+ '''
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils.data
14
+
15
+ __author__ = "Kartik Prabhu, Mahantesh Pattadkal, and Soumick Chatterjee"
16
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
17
+ __credits__ = ["Kartik Prabhu", "Mahantesh Pattadkal", "Soumick Chatterjee"]
18
+ __license__ = "GPL"
19
+ __version__ = "1.0.0"
20
+ __maintainer__ = "Soumick Chatterjee"
21
+ __email__ = "soumick.chatterjee@ovgu.de"
22
+ __status__ = "Production"
23
+
24
+ class conv_block(nn.Module):
25
+ """
26
+ Convolution Block
27
+ """
28
+
29
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
30
+ super(conv_block, self).__init__()
31
+ self.conv = nn.Sequential(
32
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
33
+ stride=stride, padding=padding, bias=bias),
34
+ nn.BatchNorm3d(num_features=out_channels),
35
+ nn.ReLU(inplace=True),
36
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
37
+ stride=stride, padding=padding, bias=bias),
38
+ nn.BatchNorm3d(num_features=out_channels),
39
+ nn.ReLU(inplace=True)
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.conv(x)
44
+ return x
45
+
46
+
47
+ class up_conv(nn.Module):
48
+ """
49
+ Up Convolution Block
50
+ """
51
+
52
+ # def __init__(self, in_ch, out_ch):
53
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
54
+ super(up_conv, self).__init__()
55
+ self.up = nn.Sequential(
56
+ nn.Upsample(scale_factor=2),
57
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
58
+ stride=stride, padding=padding, bias=bias),
59
+ nn.BatchNorm3d(num_features=out_channels),
60
+ nn.ReLU(inplace=True))
61
+
62
+ def forward(self, x):
63
+ x = self.up(x)
64
+ return x
65
+
66
+
67
+ class U_Net(nn.Module):
68
+ """
69
+ UNet - Basic Implementation
70
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
71
+ Paper : https://arxiv.org/abs/1505.04597
72
+ """
73
+
74
+ def __init__(self, in_ch=1, out_ch=1, init_features=64):
75
+ super(U_Net, self).__init__()
76
+
77
+ n1 = init_features
78
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
79
+
80
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
81
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
82
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
83
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
84
+
85
+ self.Conv1 = conv_block(in_ch, filters[0])
86
+ self.Conv2 = conv_block(filters[0], filters[1])
87
+ self.Conv3 = conv_block(filters[1], filters[2])
88
+ self.Conv4 = conv_block(filters[2], filters[3])
89
+ self.Conv5 = conv_block(filters[3], filters[4])
90
+
91
+ self.Up5 = up_conv(filters[4], filters[3])
92
+ self.Up_conv5 = conv_block(filters[4], filters[3])
93
+
94
+ self.Up4 = up_conv(filters[3], filters[2])
95
+ self.Up_conv4 = conv_block(filters[3], filters[2])
96
+
97
+ self.Up3 = up_conv(filters[2], filters[1])
98
+ self.Up_conv3 = conv_block(filters[2], filters[1])
99
+
100
+ self.Up2 = up_conv(filters[1], filters[0])
101
+ self.Up_conv2 = conv_block(filters[1], filters[0])
102
+
103
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
104
+
105
+ # self.active = torch.nn.Sigmoid()
106
+
107
+ def forward(self, x):
108
+ # print("unet")
109
+ # print(x.shape)
110
+ # print(padded.shape)
111
+
112
+ e1 = self.Conv1(x)
113
+ # print("conv1:")
114
+ # print(e1.shape)
115
+
116
+ e2 = self.Maxpool1(e1)
117
+ e2 = self.Conv2(e2)
118
+ # print("conv2:")
119
+ # print(e2.shape)
120
+
121
+ e3 = self.Maxpool2(e2)
122
+ e3 = self.Conv3(e3)
123
+ # print("conv3:")
124
+ # print(e3.shape)
125
+
126
+ e4 = self.Maxpool3(e3)
127
+ e4 = self.Conv4(e4)
128
+ # print("conv4:")
129
+ # print(e4.shape)
130
+
131
+ e5 = self.Maxpool4(e4)
132
+ e5 = self.Conv5(e5)
133
+ # print("conv5:")
134
+ # print(e5.shape)
135
+
136
+ d5 = self.Up5(e5)
137
+ # print("d5:")
138
+ # print(d5.shape)
139
+ # print("e4:")
140
+ # print(e4.shape)
141
+ d5 = torch.cat((e4, d5), dim=1)
142
+ d5 = self.Up_conv5(d5)
143
+ # print("upconv5:")
144
+ # print(d5.size)
145
+
146
+ d4 = self.Up4(d5)
147
+ # print("d4:")
148
+ # print(d4.shape)
149
+ d4 = torch.cat((e3, d4), dim=1)
150
+ d4 = self.Up_conv4(d4)
151
+ # print("upconv4:")
152
+ # print(d4.shape)
153
+ d3 = self.Up3(d4)
154
+ d3 = torch.cat((e2, d3), dim=1)
155
+ d3 = self.Up_conv3(d3)
156
+ # print("upconv3:")
157
+ # print(d3.shape)
158
+ d2 = self.Up2(d3)
159
+ d2 = torch.cat((e1, d2), dim=1)
160
+ d2 = self.Up_conv2(d2)
161
+ # print("upconv2:")
162
+ # print(d2.shape)
163
+ out = self.Conv(d2)
164
+ # print("out:")
165
+ # print(out.shape)
166
+ # d1 = self.active(out)
167
+
168
+ return [out]
169
+
170
+ class U_Net_DeepSup(nn.Module):
171
+ """
172
+ UNet - Basic Implementation
173
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
174
+ Paper : https://arxiv.org/abs/1505.04597
175
+ """
176
+
177
+ def __init__(self, in_ch=1, out_ch=1, init_features=64):
178
+ super(U_Net_DeepSup, self).__init__()
179
+
180
+ n1 = init_features
181
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
182
+
183
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
184
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
185
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
186
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
187
+
188
+ self.Conv1 = conv_block(in_ch, filters[0])
189
+ self.Conv2 = conv_block(filters[0], filters[1])
190
+ self.Conv3 = conv_block(filters[1], filters[2])
191
+ self.Conv4 = conv_block(filters[2], filters[3])
192
+ self.Conv5 = conv_block(filters[3], filters[4])
193
+
194
+ #1x1x1 Convolution for Deep Supervision
195
+ self.Conv_d3 = conv_block(filters[1], 1)
196
+ self.Conv_d4 = conv_block(filters[2], 1)
197
+
198
+
199
+
200
+ self.Up5 = up_conv(filters[4], filters[3])
201
+ self.Up_conv5 = conv_block(filters[4], filters[3])
202
+
203
+ self.Up4 = up_conv(filters[3], filters[2])
204
+ self.Up_conv4 = conv_block(filters[3], filters[2])
205
+
206
+ self.Up3 = up_conv(filters[2], filters[1])
207
+ self.Up_conv3 = conv_block(filters[2], filters[1])
208
+
209
+ self.Up2 = up_conv(filters[1], filters[0])
210
+ self.Up_conv2 = conv_block(filters[1], filters[0])
211
+
212
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
213
+
214
+ for submodule in self.modules():
215
+ submodule.register_forward_hook(self.nan_hook)
216
+
217
+ # self.active = torch.nn.Sigmoid()
218
+
219
+ def nan_hook(self, module, inp, output):
220
+ for i, out in enumerate(output):
221
+ nan_mask = torch.isnan(out)
222
+ if nan_mask.any():
223
+ print("In", self.__class__.__name__)
224
+ print(module)
225
+ raise RuntimeError(f"Found NAN in output {i} at indices: ", nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])
226
+
227
+ def forward(self, x):
228
+ # print("unet")
229
+ # print(x.shape)
230
+ # print(padded.shape)
231
+
232
+ e1 = self.Conv1(x)
233
+ # print("conv1:")
234
+ # print(e1.shape)
235
+
236
+ e2 = self.Maxpool1(e1)
237
+ e2 = self.Conv2(e2)
238
+ # print("conv2:")
239
+ # print(e2.shape)
240
+
241
+ e3 = self.Maxpool2(e2)
242
+ e3 = self.Conv3(e3)
243
+ # print("conv3:")
244
+ # print(e3.shape)
245
+
246
+ e4 = self.Maxpool3(e3)
247
+ e4 = self.Conv4(e4)
248
+ # print("conv4:")
249
+ # print(e4.shape)
250
+
251
+ e5 = self.Maxpool4(e4)
252
+ e5 = self.Conv5(e5)
253
+ # print("conv5:")
254
+ # print(e5.shape)
255
+
256
+ d5 = self.Up5(e5)
257
+ # print("d5:")
258
+ # print(d5.shape)
259
+ # print("e4:")
260
+ # print(e4.shape)
261
+ d5 = torch.cat((e4, d5), dim=1)
262
+ d5 = self.Up_conv5(d5)
263
+ # print("upconv5:")
264
+ # print(d5.size)
265
+
266
+ d4 = self.Up4(d5)
267
+ # print("d4:")
268
+ # print(d4.shape)
269
+ d4 = torch.cat((e3, d4), dim=1)
270
+ d4 = self.Up_conv4(d4)
271
+ d4_out = self.Conv_d4(d4)
272
+
273
+
274
+ # print("upconv4:")
275
+ # print(d4.shape)
276
+ d3 = self.Up3(d4)
277
+ d3 = torch.cat((e2, d3), dim=1)
278
+ d3 = self.Up_conv3(d3)
279
+ d3_out = self.Conv_d3(d3)
280
+
281
+ # print("upconv3:")
282
+ # print(d3.shape)
283
+ d2 = self.Up2(d3)
284
+ d2 = torch.cat((e1, d2), dim=1)
285
+ d2 = self.Up_conv2(d2)
286
+ # print("upconv2:")
287
+ # print(d2.shape)
288
+ out = self.Conv(d2)
289
+ # print("out:")
290
+ # print(out.shape)
291
+ # d1 = self.active(out)
292
+
293
+ return [out, d3_out , d4_out]
294
+
295
+