bigmed@bigmed commited on
Commit
8ef1fbf
1 Parent(s): 67849dc

monkey patched timm swintransformer model

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. pipline.py +27 -1
app.py CHANGED
@@ -106,7 +106,7 @@ def Final_Compute_regression_results_Sample(Model, batch_sampler,num_head=2):
106
  if ratios.vcdr < 0.6:
107
  glaucoma = 'None (Rule of thumb ~0.6 or greater is suspicious)'
108
  else:
109
- glaucoma = 'May be thre is a risk of Glaucoma (Rule of thumb ~0.6 or greater is suspicious)'
110
 
111
  # print('Galucoma:')
112
 
 
106
  if ratios.vcdr < 0.6:
107
  glaucoma = 'None (Rule of thumb ~0.6 or greater is suspicious)'
108
  else:
109
+ glaucoma = 'May be there is a risk of Glaucoma (Rule of thumb ~0.6 or greater is suspicious)'
110
 
111
  # print('Galucoma:')
112
 
pipline.py CHANGED
@@ -11,12 +11,38 @@ from collections import namedtuple
11
 
12
  # check you have the right version of timm
13
  # assert timm.__version__ == "0.3.2"
14
- from timm.models.swin_transformer import swin_base_patch4_window12_384_in22k
15
 
16
  torch.manual_seed(0)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  pad_value = 10
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def extract_regions_Last(img_test, ytruth, pad1=pad_value, pad2=pad_value, pad3=pad_value, pad4=pad_value):
22
 
 
11
 
12
  # check you have the right version of timm
13
  # assert timm.__version__ == "0.3.2"
14
+ from timm.models.swin_transformer import swin_base_patch4_window12_384_in22k, SwinTransformer
15
 
16
  torch.manual_seed(0)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  pad_value = 10
19
 
20
+ def forward_features(self, x):
21
+ x = self.patch_embed(x)
22
+ if self.absolute_pos_embed is not None:
23
+ x = x + self.absolute_pos_embed
24
+ x = self.pos_drop(x)
25
+
26
+ hide=[]
27
+ for layer in self.layers:
28
+ x = layer(x)
29
+ #print(x.shape)
30
+ hide.append(x)
31
+
32
+ #x = self.layers(x)
33
+ x = self.norm(x) # B L C
34
+ return hide
35
+
36
+ def forward(self, x):
37
+ x = self.forward_features(x)
38
+ #x = self.forward_head(x)
39
+ return x
40
+
41
+ SwinTransformer.forward_features = forward_features
42
+ SwinTransformer.forward = forward
43
+
44
+
45
+
46
 
47
  def extract_regions_Last(img_test, ytruth, pad1=pad_value, pad2=pad_value, pad3=pad_value, pad4=pad_value):
48