HaoFeng2019 committed
Commit 240c20c (1 parent: 2ade88a)

Upload 12 files
app.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import cv2
+ import os
+ from PIL import Image
+ import warnings
+ import gradio as gr
+ 
+ from model import DocGeoNet
+ from seg import U2NETP
+ import glob
+ 
+ 
+ warnings.filterwarnings('ignore')
+ 
+ 
+ class Net(nn.Module):
+     def __init__(self):
+         super(Net, self).__init__()
+         self.msk = U2NETP(3, 1)
+         self.DocTr = DocGeoNet()
+ 
+     def forward(self, x):
+         msk, *_ = self.msk(x)  # keep the fused mask, drop the six side outputs
+         msk = (msk > 0.5).float()
+         x = msk * x
+ 
+         _, _, bm = self.DocTr(x)
+         bm = (2 * (bm / 255.) - 1) * 0.99
+ 
+         return bm
+ 
+ 
+ def reload_seg_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cpu')
+         # strip the 6-character training-time key prefix so names match this module
+         pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+         return model
+ 
+ 
+ def reload_rec_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cpu')
+         # strip the 7-character training-time key prefix so names match this module
+         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+         return model
+ 
+ 
+ def rec(input_image):
+     seg_model_path = './model_pretrained/preprocess.pth'
+     rec_model_path = './model_pretrained/DocGeoNet.pth'
+ 
+     net = Net()
+     reload_rec_model(net.DocTr, rec_model_path)
+     reload_seg_model(net.msk, seg_model_path)
+     net.eval()
+ 
+     im_ori = np.array(input_image)[:, :, :3] / 255.  # normalize image from [0, 255] to [0, 1]
+     h, w, _ = im_ori.shape
+     im = cv2.resize(im_ori, (256, 256))
+     im = im.transpose(2, 0, 1)
+     im = torch.from_numpy(im).float().unsqueeze(0)
+ 
+     with torch.no_grad():
+         bm = net(im)
+         bm = bm.cpu()
+ 
+         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
+         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
+         bm0 = cv2.blur(bm0, (3, 3))
+         bm1 = cv2.blur(bm1, (3, 3))
+         lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2
+         out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
+         # the sampled tensor is already RGB; the original flipped to BGR and then
+         # converted back with cv2.cvtColor, a redundant round trip
+         img_rec = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
+ 
+     return Image.fromarray(img_rec)
+ 
+ 
+ demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorted/*.[pP][nN][gG]')
+ 
+ # Gradio interface (gr.inputs/gr.outputs were removed in Gradio 3+)
+ input_image = gr.Image()
+ output_image = gr.Image(type='pil')
+ 
+ iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet", examples=demo_img_files)
+ 
+ if __name__ == "__main__":
+     iface.launch(server_port=8821, server_name="0.0.0.0")
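To call the rectifier without the web UI, a minimal sketch (assuming the two pretrained checkpoints above are in place and app.py is on the Python path; the sample file is one of the demo images in this commit):

    from PIL import Image
    from app import rec

    rectified = rec(Image.open('./distorted/42_2 copy.png'))  # returns a PIL image
    rectified.save('rectified.png')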
distorted/42_2 copy.png ADDED
distorted/63_2 copy.png ADDED
extractor.py ADDED
@@ -0,0 +1,115 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+         super(ResidualBlock, self).__init__()
+ 
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+ 
+         num_groups = planes // 8
+ 
+         if norm_fn == 'group':
+             self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+             self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+             if stride != 1:
+                 self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ 
+         elif norm_fn == 'batch':
+             self.norm1 = nn.BatchNorm2d(planes)
+             self.norm2 = nn.BatchNorm2d(planes)
+             if stride != 1:
+                 self.norm3 = nn.BatchNorm2d(planes)
+ 
+         elif norm_fn == 'instance':
+             self.norm1 = nn.InstanceNorm2d(planes)
+             self.norm2 = nn.InstanceNorm2d(planes)
+             if stride != 1:
+                 self.norm3 = nn.InstanceNorm2d(planes)
+ 
+         elif norm_fn == 'none':
+             self.norm1 = nn.Sequential()
+             self.norm2 = nn.Sequential()
+             if stride != 1:
+                 self.norm3 = nn.Sequential()
+ 
+         if stride == 1:
+             self.downsample = None
+         else:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+ 
+     def forward(self, x):
+         y = x
+         y = self.relu(self.norm1(self.conv1(y)))
+         y = self.relu(self.norm2(self.conv2(y)))
+ 
+         if self.downsample is not None:
+             x = self.downsample(x)
+ 
+         return self.relu(x + y)
+ 
+ 
+ class BasicEncoder(nn.Module):
+     def __init__(self, input_dim=128, output_dim=128, norm_fn='batch'):
+         super(BasicEncoder, self).__init__()
+         self.norm_fn = norm_fn
+ 
+         if self.norm_fn == 'group':
+             self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+         elif self.norm_fn == 'batch':
+             self.norm1 = nn.BatchNorm2d(64)
+         elif self.norm_fn == 'instance':
+             self.norm1 = nn.InstanceNorm2d(64)
+         elif self.norm_fn == 'none':
+             self.norm1 = nn.Sequential()
+ 
+         self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)
+         self.relu1 = nn.ReLU(inplace=True)
+ 
+         self.in_planes = 64
+         self.layer1 = self._make_layer(64, stride=1)
+         self.layer2 = self._make_layer(128, stride=2)
+         self.layer3 = self._make_layer(192, stride=2)
+ 
+         # output convolution
+         self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1)
+ 
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                 if m.weight is not None:
+                     nn.init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+ 
+     def _make_layer(self, dim, stride=1):
+         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+         layers = (layer1, layer2)
+ 
+         self.in_planes = dim
+         return nn.Sequential(*layers)
+ 
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.norm1(x)
+         x = self.relu1(x)
+ 
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+ 
+         x = self.conv2(x)
+ 
+         return x
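BasicEncoder downsamples by 8 overall (a stride-2 stem followed by two stride-2 residual stages), which is where DocGeoNet's H/8 x W/8 feature grid comes from. A minimal shape check, assuming extractor.py is importable:

    import torch
    from extractor import BasicEncoder

    enc = BasicEncoder(input_dim=3, output_dim=128, norm_fn='instance')
    with torch.no_grad():
        feat = enc(torch.zeros(1, 3, 256, 256))
    print(feat.shape)  # torch.Size([1, 128, 32, 32])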
inference.py ADDED
@@ -0,0 +1,128 @@
+ from model import DocGeoNet
+ from seg import U2NETP
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import skimage.io as io
+ import numpy as np
+ import cv2
+ import os
+ from PIL import Image
+ import argparse
+ import warnings
+ warnings.filterwarnings('ignore')
+ 
+ 
+ class Net(nn.Module):
+     def __init__(self, opt):
+         super(Net, self).__init__()
+         self.msk = U2NETP(3, 1)
+         self.DocTr = DocGeoNet()
+ 
+     def forward(self, x):
+         msk, *_ = self.msk(x)  # keep the fused mask, drop the six side outputs
+         msk = (msk > 0.5).float()
+         x = msk * x
+ 
+         _, _, bm = self.DocTr(x)
+         bm = (2 * (bm / 255.) - 1) * 0.99
+ 
+         return bm
+ 
+ 
+ def reload_seg_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cpu')
+         print(len(pretrained_dict.keys()))
+         # strip the 6-character training-time key prefix so names match this module
+         pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
+         print(len(pretrained_dict.keys()))
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+ 
+         return model
+ 
+ 
+ def reload_rec_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cpu')
+         print(len(pretrained_dict.keys()))
+         # strip the 7-character training-time key prefix so names match this module
+         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
+         print(len(pretrained_dict.keys()))
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+ 
+         return model
+ 
+ 
+ def rec(seg_model_path, rec_model_path, distorted_path, save_path, opt):
+     print(torch.__version__)
+ 
+     # distorted images list
+     img_list = sorted(os.listdir(distorted_path))
+ 
+     # create save path for rectified images
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+ 
+     net = Net(opt)  # .cuda()
+     print(get_parameter_number(net))
+ 
+     # reload rec model
+     reload_rec_model(net.DocTr, rec_model_path)
+     reload_seg_model(net.msk, seg_model_path)
+ 
+     net.eval()
+ 
+     for img_name in img_list:
+         name = os.path.splitext(img_name)[0]  # image name without extension
+         img_path = os.path.join(distorted_path, img_name)  # image path
+ 
+         im_ori = np.array(Image.open(img_path))[:, :, :3] / 255.  # normalize image from [0, 255] to [0, 1]
+         h, w, _ = im_ori.shape
+         im = cv2.resize(im_ori, (256, 256))
+         im = im.transpose(2, 0, 1)
+         im = torch.from_numpy(im).float().unsqueeze(0)
+ 
+         with torch.no_grad():
+             bm = net(im)
+             bm = bm.cpu()
+ 
+             # save rectified image
+             bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
+             bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
+             bm0 = cv2.blur(bm0, (3, 3))
+             bm1 = cv2.blur(bm1, (3, 3))
+             lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2
+             out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
+             # RGB -> BGR flip here is intentional: cv2.imwrite expects BGR
+             cv2.imwrite(os.path.join(save_path, name + '_rec.png'),
+                         ((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1].astype(np.uint8))
+ 
+ 
+ def get_parameter_number(net):
+     total_num = sum(p.numel() for p in net.parameters())
+     trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
+     return {'Total': total_num, 'Trainable': trainable_num}
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--seg_model_path', default='./model_pretrained/preprocess.pth')
+     parser.add_argument('--rec_model_path', default='./model_pretrained/DocGeoNet.pth')
+     parser.add_argument('--distorted_path', default='./distorted/')
+     parser.add_argument('--save_path', default='./rec/')
+     opt = parser.parse_args()
+ 
+     rec(seg_model_path=opt.seg_model_path,
+         rec_model_path=opt.rec_model_path,
+         distorted_path=opt.distorted_path,
+         save_path=opt.save_path,
+         opt=opt)
+ 
+ 
+ if __name__ == "__main__":
+     main()
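inference.py rectifies every image under --distorted_path and writes <name>_rec.png files to --save_path. A sketch of driving rec() from Python instead of the command line (the paths are the script's own defaults):

    import argparse
    from inference import rec

    opt = argparse.Namespace(
        seg_model_path='./model_pretrained/preprocess.pth',
        rec_model_path='./model_pretrained/DocGeoNet.pth',
        distorted_path='./distorted/',
        save_path='./rec/')
    rec(opt.seg_model_path, opt.rec_model_path, opt.distorted_path, opt.save_path, opt)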
model.py ADDED
@@ -0,0 +1,267 @@
+ from extractor import BasicEncoder
+ from position_encoding import build_position_encoding
+ from unet import U_Net_mini
+ 
+ import argparse
+ import numpy as np
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ import copy
+ from typing import Optional
+ 
+ 
+ class attnLayer(nn.Module):
+     def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn_list = nn.ModuleList(
+             [copy.deepcopy(nn.MultiheadAttention(d_model, nhead, dropout=dropout)) for i in range(2)])
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+ 
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2_list = nn.ModuleList([copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)])
+ 
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2_list = nn.ModuleList([copy.deepcopy(nn.Dropout(dropout)) for i in range(2)])
+         self.dropout3 = nn.Dropout(dropout)
+ 
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+ 
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+ 
+     def forward_post(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                      tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                      pos=None, memory_pos=None):
+         q = k = self.with_pos_embed(tgt, pos)
+         tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         for memory, multihead_attn, norm2, dropout2, m_pos in zip(memory_list, self.multihead_attn_list,
+                                                                   self.norm2_list, self.dropout2_list, memory_pos):
+             tgt2 = multihead_attn(query=self.with_pos_embed(tgt, pos),
+                                   key=self.with_pos_embed(memory, m_pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+             tgt = tgt + dropout2(tgt2)
+             tgt = norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+ 
+     def forward_pre(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                     tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                     pos=None, memory_pos=None):
+         # Pre-norm variant; only reached when normalize_before=True, which this
+         # repo never sets. The original body referenced undefined attributes
+         # (self.norm2, self.multihead_attn, self.dropout2); rewritten here to
+         # use the per-memory module lists, mirroring forward_post.
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, pos)
+         tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         for memory, multihead_attn, norm2, dropout2, m_pos in zip(memory_list, self.multihead_attn_list,
+                                                                   self.norm2_list, self.dropout2_list, memory_pos):
+             tgt2 = norm2(tgt)
+             tgt2 = multihead_attn(query=self.with_pos_embed(tgt2, pos),
+                                   key=self.with_pos_embed(memory, m_pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+             tgt = tgt + dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+ 
+     def forward(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                 tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                 pos=None, memory_pos=None):
+         if self.normalize_before:
+             return self.forward_pre(tgt, memory_list, tgt_mask, memory_mask,
+                                     tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos)
+         return self.forward_post(tgt, memory_list, tgt_mask, memory_mask,
+                                  tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos)
+ 
+ 
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+ 
+ 
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+ 
+ 
+ class TransDecoder(nn.Module):
+     def __init__(self, num_attn_layers, hidden_dim=128):
+         super(TransDecoder, self).__init__()
+         attn_layer = attnLayer(hidden_dim)
+         self.layers = _get_clones(attn_layer, num_attn_layers)
+         self.position_embedding = build_position_encoding(hidden_dim)
+ 
+     def forward(self, imgf, query_embed):
+         pos = self.position_embedding(
+             torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool())  # [B, hidden_dim, H/8, W/8]
+ 
+         bs, c, h, w = imgf.shape
+         imgf = imgf.flatten(2).permute(2, 0, 1)  # [H/8*W/8, B, hidden_dim]
+         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+         pos = pos.flatten(2).permute(2, 0, 1)  # [H/8*W/8, B, hidden_dim]
+ 
+         for layer in self.layers:
+             query_embed = layer(query_embed, [imgf], pos=pos, memory_pos=[pos, pos])
+         query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w)
+ 
+         return query_embed
+ 
+ 
+ class TransEncoder(nn.Module):
+     def __init__(self, num_attn_layers, hidden_dim=128):
+         super(TransEncoder, self).__init__()
+         attn_layer = attnLayer(hidden_dim)
+         self.layers = _get_clones(attn_layer, num_attn_layers)
+         self.position_embedding = build_position_encoding(hidden_dim)
+ 
+     def forward(self, imgf):
+         pos = self.position_embedding(
+             torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool())  # [B, hidden_dim, H/8, W/8]
+         bs, c, h, w = imgf.shape
+         imgf = imgf.flatten(2).permute(2, 0, 1)
+         pos = pos.flatten(2).permute(2, 0, 1)
+ 
+         for layer in self.layers:
+             imgf = layer(imgf, [imgf], pos=pos, memory_pos=[pos, pos])
+         imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w)
+ 
+         return imgf
+ 
+ 
+ class FlowHead(nn.Module):
+     def __init__(self, input_dim=128, hidden_dim=256, out_cha=2):
+         super(FlowHead, self).__init__()
+         self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+         self.conv2 = nn.Conv2d(hidden_dim, out_cha, 3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         return self.conv2(self.relu(self.conv1(x)))
+ 
+ 
+ class UpdateBlock(nn.Module):
+     def __init__(self, hidden_dim=128):
+         super(UpdateBlock, self).__init__()
+         self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+         self.mask = nn.Sequential(
+             nn.Conv2d(hidden_dim, 256, 3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(256, 64 * 9, 1, padding=0))
+ 
+     def forward(self, imgf, coords1):
+         mask = .25 * self.mask(imgf)  # scale mask to balance gradients
+         dflow = self.flow_head(imgf)
+         coords1 = coords1 + dflow
+ 
+         return mask, coords1
+ 
+ 
+ def coords_grid(batch, ht, wd):
+     coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+     coords = torch.stack(coords[::-1], dim=0).float()
+     return coords[None].repeat(batch, 1, 1, 1)
+ 
+ 
+ def upflow8(flow, mode='bilinear'):
+     new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+     return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+ 
+ 
+ class Up_block(nn.Module):
+     def __init__(self, hidden_dim=128, out_cha=3):
+         super(Up_block, self).__init__()
+         self.flow_head = FlowHead(hidden_dim, hidden_dim=256, out_cha=out_cha)
+         self.acf = nn.Hardtanh(0, 1)
+ 
+     def forward(self, x):
+         x = self.flow_head(x)
+         x = upflow8(x)
+         x = self.acf(x)
+         return x
+ 
+ 
+ class DocGeoNet(nn.Module):
+     def __init__(self):
+         super(DocGeoNet, self).__init__()
+ 
+         self.hidden_dim = hdim = 128
+         self.imcnn = BasicEncoder(input_dim=3, output_dim=hdim, norm_fn='instance')
+ 
+         # wc branch
+         self.wc_encoder = TransEncoder(4, hidden_dim=hdim)
+         # wc tail
+         self.Up_block_wc = nn.Sequential(TransEncoder(2, hidden_dim=hdim),
+                                          Up_block(self.hidden_dim))
+ 
+         # text branch
+         self.text_encoder = U_Net_mini(3, 1)
+         self.textcnn = nn.Conv2d(128, 64, 3, 2, 1)  # BasicEncoder(input_dim=32, output_dim=64, norm_fn='instance')
+ 
+         # bm encoder (6 attention layers over the fused features)
+         self.bm_encoder = TransEncoder(6, hidden_dim=hdim + 64)
+ 
+         # bm tail
+         self.update_block = UpdateBlock(self.hidden_dim + 64)
+ 
+     def initialize_flow(self, img):
+         N, C, H, W = img.shape
+         coodslar = coords_grid(N, H, W).to(img.device)
+         coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
+         coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
+ 
+         return coodslar, coords0, coords1
+ 
+     def upsample_flow(self, flow, mask):
+         N, _, H, W = flow.shape
+         mask = mask.view(N, 1, 9, 8, 8, H, W)
+         mask = torch.softmax(mask, dim=2)
+ 
+         up_flow = F.unfold(8 * flow, [3, 3], padding=1)
+         up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+ 
+         up_flow = torch.sum(mask * up_flow, dim=2)
+         up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ 
+         return up_flow.reshape(N, 2, 8 * H, 8 * W)
+ 
+     def forward(self, image1):
+         # wc
+         imfmap = self.imcnn(image1)
+         imfmap = torch.relu(imfmap)
+         wcfea = self.wc_encoder(imfmap)
+         wc_pred = self.Up_block_wc(wcfea)
+ 
+         # text
+         d4, text_pred = self.text_encoder(image1)
+         textfea = self.textcnn(d4)
+         fmap = torch.cat((wcfea, textfea), 1)
+ 
+         # bm encoder
+         fmap = self.bm_encoder(fmap)
+ 
+         # upsample
+         coodslar, coords0, coords1 = self.initialize_flow(image1)
+         coords1 = coords1.detach()
+         mask, coords1 = self.update_block(fmap, coords1)
+         flow_up = self.upsample_flow(coords1 - coords0, mask)
+         bm_up = coodslar + flow_up
+ 
+         return wc_pred, text_pred, bm_up
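DocGeoNet returns three tensors: the three-channel wc_pred map, the one-channel text_pred mask, and the two-channel backward map bm_up. A quick shape check on a 256x256 input (random weights, so this only exercises the graph):

    import torch
    from model import DocGeoNet

    net = DocGeoNet().eval()
    with torch.no_grad():
        wc, text, bm = net(torch.zeros(1, 3, 256, 256))
    print(wc.shape, text.shape, bm.shape)
    # torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256]) torch.Size([1, 2, 256, 256])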
model_pretrained/DocGeoNet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27d7a379a92b4fe5bb347d26ef37da7c9cffbfefb09fcd8705bc9beae26e6146
+ size 95196536
model_pretrained/preprocess.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb79fdec55a5ed435dc74d8112aa9285d8213bae475022f711c709744fb19dd4
+ size 4715923
position_encoding.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+ from typing import List
+ from typing import Optional
+ from torch import Tensor
+ 
+ 
+ class NestedTensor(object):
+     def __init__(self, tensors, mask: Optional[Tensor]):
+         self.tensors = tensors
+         self.mask = mask
+ 
+     def to(self, device):
+         # type: (Device) -> NestedTensor # noqa
+         cast_tensor = self.tensors.to(device)
+         mask = self.mask
+         if mask is not None:
+             assert mask is not None
+             cast_mask = mask.to(device)
+         else:
+             cast_mask = None
+         return NestedTensor(cast_tensor, cast_mask)
+ 
+     def decompose(self):
+         return self.tensors, self.mask
+ 
+     def __repr__(self):
+         return str(self.tensors)
+ 
+ 
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention Is All You Need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+ 
+     def forward(self, mask):
+         assert mask is not None
+         y_embed = mask.cumsum(1, dtype=torch.float32)
+         x_embed = mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+ 
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32)  # .cuda()
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+ 
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         # print(pos.shape)
+         return pos
+ 
+ 
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+ 
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+ 
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return pos
+ 
+ 
+ def build_position_encoding(hidden_dim=512, position_embedding='sine'):
+     N_steps = hidden_dim // 2
+     if position_embedding in ('v2', 'sine'):
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif position_embedding in ('v3', 'learned'):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {position_embedding}")
+ 
+     return position_embedding
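build_position_encoding defaults to the sine variant and produces hidden_dim channels overall (half from the y cumsum, half from x). Usage sketch:

    import torch
    from position_encoding import build_position_encoding

    pos_enc = build_position_encoding(hidden_dim=128)  # sine, normalized
    pos = pos_enc(torch.ones(1, 32, 32).bool())
    print(pos.shape)  # torch.Size([1, 128, 32, 32])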
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy
+ opencv_python
+ Pillow
+ scikit_image
+ torch
+ torchvision
+ gradio
seg.py ADDED
@@ -0,0 +1,567 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ import torch.nn.functional as F
+ import numpy as np
+ 
+ 
+ class sobel_net(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
+         self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
+         sobel_kernelx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         sobel_kernely = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
+         self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
+ 
+         for p in self.parameters():
+             p.requires_grad = False
+ 
+     def forward(self, im):  # input rgb
+         x = (0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]).unsqueeze(1)  # rgb2gray
+         gradx = self.conv_opx(x)
+         grady = self.conv_opy(x)
+ 
+         x = (gradx ** 2 + grady ** 2) ** 0.5
+         x = (x - x.min()) / (x.max() - x.min())
+         x = F.pad(x, (1, 1, 1, 1))
+ 
+         x = torch.cat([im, x], dim=1)
+         return x
+ 
+ 
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+ 
+         self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+ 
+         return xout
+ 
+ 
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
+ 
+     return src
+ 
+ 
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+ 
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+ 
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+ 
+         hx6 = self.rebnconv6(hx)
+ 
+         hx7 = self.rebnconv7(hx6)
+ 
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+ 
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+ 
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+ 
+         hx5 = self.rebnconv5(hx)
+ 
+         hx6 = self.rebnconv6(hx5)
+ 
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+ 
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+ 
+         hx5 = self.rebnconv5(hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+ 
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+ 
+         hx4 = self.rebnconv4(hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+ 
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+ 
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+ 
+         hx4 = self.rebnconv4(hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+ 
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+         self.edge = sobel_net()
+ 
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage6 = RSU4F(512, 256, 512)
+ 
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+ 
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+ 
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+ 
+     def forward(self, x):
+         x = self.edge(x)
+         hx = x
+ 
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+ 
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+ 
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+ 
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+ 
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+ 
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+ 
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+ 
+         # side output
+         d1 = self.side1(hx1d)
+ 
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+ 
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+ 
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+ 
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+ 
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+ 
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+ 
+         return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(
+             d4), torch.sigmoid(d5), torch.sigmoid(d6)
+ 
+ 
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+ 
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+ 
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage6 = RSU4F(64, 16, 64)
+ 
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+ 
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+ 
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+ 
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+ 
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+ 
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+ 
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+ 
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+ 
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+ 
+         # side output
+         d1 = self.side1(hx1d)
+ 
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+ 
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+ 
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+ 
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+ 
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+ 
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+ 
+         return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(
+             d4), torch.sigmoid(d5), torch.sigmoid(d6)
+ 
+ 
+ def get_parameter_number(net):
+     total_num = sum(p.numel() for p in net.parameters())
+     trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
+     return {'Total': total_num, 'Trainable': trainable_num}
+ 
+ 
+ if __name__ == '__main__':
+     net = U2NET(4, 1)  # .cuda()
+     print(get_parameter_number(net))  # 69090500 params; 69442032 after adding attention
+     with torch.no_grad():
+         inputs = torch.zeros(1, 3, 256, 256)  # .cuda()
+         outs = net(inputs)
+         print(outs[0].shape)  # torch.Size([1, 1, 256, 256])
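The app only uses the lightweight U2NETP variant, whose forward returns the fused mask followed by six side outputs. A shape check with random weights:

    import torch
    from seg import U2NETP

    msk_net = U2NETP(3, 1).eval()
    with torch.no_grad():
        d0, *side = msk_net(torch.zeros(1, 3, 256, 256))
    print(d0.shape, len(side))  # torch.Size([1, 1, 256, 256]) 6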
unet.py ADDED
@@ -0,0 +1,401 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ import torch.nn.functional as F
+ import numpy as np  # needed by sobel_net's kernels; missing in the original file
+ 
+ 
+ class sobel_net(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
+         self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
+         sobel_kernelx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         sobel_kernely = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
+         self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
+ 
+         for p in self.parameters():
+             p.requires_grad = False
+ 
+     def forward(self, im):  # input rgb
+         x = (0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]).unsqueeze(1)  # rgb2gray
+         gradx = self.conv_opx(x)
+         grady = self.conv_opy(x)
+ 
+         x = (gradx ** 2 + grady ** 2) ** 0.5
+         x = (x - x.min()) / (x.max() - x.min())
+         x = F.pad(x, (1, 1, 1, 1))
+ 
+         x = torch.cat([im, x], dim=1)
+         return x
+ 
+ 
+ class conv_block(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super(conv_block, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(ch_out),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(ch_out),
+             nn.ReLU(inplace=True)
+         )
+ 
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+ 
+ 
+ class up_conv(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super(up_conv, self).__init__()
+         self.up = nn.Sequential(
+             nn.Upsample(scale_factor=2),
+             nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(ch_out),
+             nn.ReLU(inplace=True)
+         )
+ 
+     def forward(self, x):
+         x = self.up(x)
+         return x
+ 
+ 
+ class Recurrent_block(nn.Module):
+     def __init__(self, ch_out, t=2):
+         super(Recurrent_block, self).__init__()
+         self.t = t
+         self.ch_out = ch_out
+         self.conv = nn.Sequential(
+             nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(ch_out),
+             nn.ReLU(inplace=True)
+         )
+ 
+     def forward(self, x):
+         for i in range(self.t):
+             if i == 0:
+                 x1 = self.conv(x)
+             x1 = self.conv(x + x1)
+         return x1
+ 
+ 
+ class RRCNN_block(nn.Module):
+     def __init__(self, ch_in, ch_out, t=2):
+         super(RRCNN_block, self).__init__()
+         self.RCNN = nn.Sequential(
+             Recurrent_block(ch_out, t=t),
+             Recurrent_block(ch_out, t=t)
+         )
+         self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
+ 
+     def forward(self, x):
+         x = self.Conv_1x1(x)
+         x1 = self.RCNN(x)
+         return x + x1
+ 
+ 
+ class single_conv(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super(single_conv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(ch_out),
+             nn.ReLU(inplace=True)
+         )
+ 
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+ 
+ 
+ class Attention_block(nn.Module):
+     def __init__(self, F_g, F_l, F_int):
+         super(Attention_block, self).__init__()
+         self.W_g = nn.Sequential(
+             nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(F_int)
+         )
+ 
+         self.W_x = nn.Sequential(
+             nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(F_int)
+         )
+ 
+         self.psi = nn.Sequential(
+             nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(1),
+             nn.Sigmoid()
+         )
+ 
+         self.relu = nn.ReLU(inplace=True)
+ 
+     def forward(self, g, x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1 + x1)
+         psi = self.psi(psi)
+ 
+         return x * psi
+ 
+ 
+ class U_Net(nn.Module):
+     def __init__(self, img_ch=3, output_ch=1):
+         super(U_Net, self).__init__()
+ 
+         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ 
+         self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
+         self.Conv2 = conv_block(ch_in=64, ch_out=128)
+         self.Conv3 = conv_block(ch_in=128, ch_out=256)
+         self.Conv4 = conv_block(ch_in=256, ch_out=512)
+         self.Conv5 = conv_block(ch_in=512, ch_out=1024)
+ 
+         self.Up5 = up_conv(ch_in=1024, ch_out=512)
+         self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
+ 
+         self.Up4 = up_conv(ch_in=512, ch_out=256)
+         self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
+ 
+         self.Up3 = up_conv(ch_in=256, ch_out=128)
+         self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
+ 
+         self.Up2 = up_conv(ch_in=128, ch_out=64)
+         self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
+ 
+         self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0, bias=False)
+ 
+     def forward(self, x):
+         # encoding path
+         x1 = self.Conv1(x)
+ 
+         x2 = self.Maxpool(x1)
+         x2 = self.Conv2(x2)
+ 
+         x3 = self.Maxpool(x2)
+         x3 = self.Conv3(x3)
+ 
+         x4 = self.Maxpool(x3)
+         x4 = self.Conv4(x4)
+ 
+         x5 = self.Maxpool(x4)
+         x5 = self.Conv5(x5)
+ 
+         # decoding + concat path
+         d5 = self.Up5(x5)
+         d5 = torch.cat((x4, d5), dim=1)
+         d5 = self.Up_conv5(d5)
+ 
+         d4 = self.Up4(d5)
+         d4 = torch.cat((x3, d4), dim=1)
+         d4 = self.Up_conv4(d4)
+ 
+         d3 = self.Up3(d4)
+         d3 = torch.cat((x2, d3), dim=1)
+         d3 = self.Up_conv3(d3)
+ 
+         d2 = self.Up2(d3)
+         d2 = torch.cat((x1, d2), dim=1)
+         d2 = self.Up_conv2(d2)
+ 
+         out = self.Conv_1x1(d2)
+         out = torch.sigmoid(out)
+ 
+         return out
+ 
+ 
+ class U_Net_mini(nn.Module):
+     def __init__(self, img_ch=3, output_ch=1):
+         super(U_Net_mini, self).__init__()
+ 
+         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ 
+         self.Conv1 = conv_block(ch_in=img_ch, ch_out=32)
+         self.Conv2 = conv_block(ch_in=32, ch_out=64)
+         self.Conv3 = conv_block(ch_in=64, ch_out=128)
+         self.Conv4 = conv_block(ch_in=128, ch_out=256)
+         self.Conv5 = conv_block(ch_in=256, ch_out=512)
+ 
+         self.Up5 = up_conv(ch_in=512, ch_out=256)
+         self.Up_conv5 = conv_block(ch_in=512, ch_out=256)
+ 
+         self.Up4 = up_conv(ch_in=256, ch_out=128)
+         self.Up_conv4 = conv_block(ch_in=256, ch_out=128)
+ 
+         self.Up3 = up_conv(ch_in=128, ch_out=64)
+         self.Up_conv3 = conv_block(ch_in=128, ch_out=64)
+ 
+         self.Up2 = up_conv(ch_in=64, ch_out=32)
+         self.Up_conv2 = conv_block(ch_in=64, ch_out=32)
+ 
+         self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0, bias=False)
+ 
+     def forward(self, x):
+         # encoding path
+         x1 = self.Conv1(x)
+ 
+         x2 = self.Maxpool(x1)
+         x2 = self.Conv2(x2)
+ 
+         x3 = self.Maxpool(x2)
+         x3 = self.Conv3(x3)
+ 
+         x4 = self.Maxpool(x3)
+         x4 = self.Conv4(x4)
+ 
+         x5 = self.Maxpool(x4)
+         x5 = self.Conv5(x5)
+ 
+         # decoding + concat path
+         d5 = self.Up5(x5)
+         d5 = torch.cat((x4, d5), dim=1)
+         d5 = self.Up_conv5(d5)
+ 
+         d4 = self.Up4(d5)
+         d4 = torch.cat((x3, d4), dim=1)
+         d4 = self.Up_conv4(d4)
+ 
+         d3 = self.Up3(d4)
+         d3 = torch.cat((x2, d3), dim=1)
+         d3 = self.Up_conv3(d3)
+ 
+         d2 = self.Up2(d3)
+         d2 = torch.cat((x1, d2), dim=1)
+         d2 = self.Up_conv2(d2)
+ 
+         out = self.Conv_1x1(d2)
+         out = torch.sigmoid(out)
+ 
+         return d4, out
+ 
+ 
+ class AttU_Net(nn.Module):
+     def __init__(self, img_ch=3, output_ch=1, need_feature_maps=False):
+         super(AttU_Net, self).__init__()
+ 
+         self.conv1_ = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
+                                     nn.BatchNorm2d(64),
+                                     nn.ReLU(inplace=True))
+ 
+         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ 
+         self.Conv1 = conv_block(ch_in=64, ch_out=64)
+         self.Conv2 = conv_block(ch_in=64, ch_out=128)
+         self.Conv3 = conv_block(ch_in=128, ch_out=256)
+         self.Conv4 = conv_block(ch_in=256, ch_out=512)
+         self.Conv5 = conv_block(ch_in=512, ch_out=1024)
+ 
+         self.Up5 = up_conv(ch_in=1024, ch_out=512)
+         self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+         self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
+ 
+         self.Up4 = up_conv(ch_in=512, ch_out=256)
+         self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
+         self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
+ 
+         self.Up3 = up_conv(ch_in=256, ch_out=128)
+         self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
+         self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
+ 
+         self.Up2 = up_conv(ch_in=128, ch_out=64)
+         self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
+         self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
+ 
+         self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
+ 
+         self.need_feature_maps = need_feature_maps
+ 
+         # self.loc_xy = (torch.stack(torch.meshgrid([torch.arange(0, 256), torch.arange(0, 256)])).permute(0, 2, 1).unsqueeze(0).float() - 127.5) / 127.5  # 1*2*256*256
+ 
+     def forward(self, x):
+         # encoding path
+         # batch = x.size(0)
+         # if self.need_feature_maps:
+         #     x = torch.cat((x, self.loc_xy.repeat(batch, 1, 1, 1).cuda()), dim=1)
+         x1 = self.conv1_(x)
+         x1 = self.Conv1(x1)
+ 
+         x2 = self.Maxpool(x1)
+         x2 = self.Conv2(x2)
+ 
+         x3 = self.Maxpool(x2)
+         x3 = self.Conv3(x3)
+ 
+         x4 = self.Maxpool(x3)
+         x4 = self.Conv4(x4)
+ 
+         x5 = self.Maxpool(x4)
+         x5 = self.Conv5(x5)
+ 
+         # decoding + concat path
+         d5 = self.Up5(x5)
+         x4 = self.Att5(g=d5, x=x4)
+         d5 = torch.cat((x4, d5), dim=1)
+         d5 = self.Up_conv5(d5)
+ 
+         d4 = self.Up4(d5)
+         x3 = self.Att4(g=d4, x=x3)
+         d4 = torch.cat((x3, d4), dim=1)
+         d4 = self.Up_conv4(d4)
+ 
+         d3 = self.Up3(d4)
+         x2 = self.Att3(g=d3, x=x2)
+         d3 = torch.cat((x2, d3), dim=1)
+         d3 = self.Up_conv3(d3)
+ 
+         d2 = self.Up2(d3)
+         x1 = self.Att2(g=d2, x=x1)
+         d2 = torch.cat((x1, d2), dim=1)
+         d2 = self.Up_conv2(d2)
+ 
+         wc = self.Conv_1x1(d2)
+ 
+         if self.need_feature_maps:
+             return d2, wc
+         else:
+             return wc  # fixed: the original returned an undefined name `bm`
+ 
+ 
+ class Doc_UNet(nn.Module):
+     def __init__(self):
+         super(Doc_UNet, self).__init__()
+         self.U_net1 = AttU_Net(3, 3, need_feature_maps=True)
+         self.U_net2 = U_Net(64 + 3 + 2, 2)  # U_Net takes no need_feature_maps argument
+         self.htan = nn.Hardtanh(0, 1.0)
+         self.f_activation = nn.Hardtanh()
+ 
+         self.loc_xy = (torch.stack(torch.meshgrid([torch.arange(0, 128), torch.arange(0, 128)])).permute(
+             0, 2, 1).unsqueeze(0).float() - 63.5) / 63.5  # 1*2*128*128
+ 
+     def forward(self, x):
+         batch = x.size(0)
+ 
+         feature_maps, wc = self.U_net1(x)
+         wc = self.htan(wc)
+ 
+         x = torch.cat((self.loc_xy.repeat(batch, 1, 1, 1).cuda(), wc, feature_maps), dim=1)
+         bm = self.U_net2(x)
+         bm = self.f_activation(bm)
+ 
+         return wc, bm
+ 
+ 
+ def get_parameter_number(net):
+     total_num = sum(p.numel() for p in net.parameters())
+     trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
+     return {'Total': total_num, 'Trainable': trainable_num}
+ 
+ 
+ if __name__ == '__main__':
+     # smoke test for U_Net_mini, the variant DocGeoNet actually uses
+     # (the original tested U2NET, which lives in seg.py and is undefined here)
+     net = U_Net_mini(3, 1)  # .cuda()
+     print(get_parameter_number(net))
+     with torch.no_grad():
+         inputs = torch.zeros(1, 3, 256, 256)  # .cuda()
+         d4, out = net(inputs)
+         print(d4.shape, out.shape)  # torch.Size([1, 128, 64, 64]) torch.Size([1, 1, 256, 256])