Hector Lopez commited on
Commit
e5bb367
1 Parent(s): 8123f86

feature: Use ViT as classifier

Browse files
Files changed (3) hide show
  1. app.py +4 -4
  2. classifier.py +56 -4
  3. model.py +10 -4
app.py CHANGED
@@ -5,14 +5,14 @@ import cv2
5
  import PIL
6
  import torch
7
 
8
- from classifier import CustomEfficientNet
9
  from model import get_model, predict, prepare_prediction, predict_class
10
 
11
  print('Creating the model')
12
- model = get_model('checkpoint.ckpt')
13
  print('Loading the classifier')
14
- classifier = CustomEfficientNet(target_size=7, pretrained=False)
15
- classifier.load_state_dict(torch.load('class_efficientB0_taco_7_class.pth', map_location='cpu'))
16
 
17
  def plot_img_no_mask(image, boxes, labels):
18
  colors = {
 
5
  import PIL
6
  import torch
7
 
8
+ from classifier import CustomEfficientNet, CustomViT
9
  from model import get_model, predict, prepare_prediction, predict_class
10
 
11
  print('Creating the model')
12
+ model = get_model('efficientDet_icevision.ckpt')
13
  print('Loading the classifier')
14
+ classifier = CustomViT(target_size=7, pretrained=False)
15
+ classifier.load_state_dict(torch.load('class_ViT_taco_7_class.pth', map_location='cpu'))
16
 
17
  def plot_img_no_mask(image, boxes, labels):
18
  colors = {
classifier.py CHANGED
@@ -1,12 +1,27 @@
1
  import timm
2
  import torch.nn as nn
3
-
4
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def get_efficientnet(model_name):
7
- model = timm.create_model(model_name, pretrained=True)
8
 
9
- return model
10
 
11
  class CustomEfficientNet(nn.Module):
12
  """
@@ -43,3 +58,40 @@ class CustomEfficientNet(nn.Module):
43
  x = self.model(x)
44
 
45
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import timm
2
  import torch.nn as nn
3
+ import albumentations as A
4
  import torch
5
+ import cv2
6
+
7
+ class CustomNormalization(A.ImageOnlyTransform):
8
+ def _norm(self, img):
9
+ return img / 255.
10
+
11
+ def apply(self, img, **params):
12
+ return self._norm(img)
13
+
14
+ def transform_image(image, size):
15
+ transforms = [
16
+ A.Resize(size, size,
17
+ interpolation=cv2.INTER_NEAREST),
18
+ CustomNormalization(p=1),
19
+ ]
20
 
21
+ augs = A.Compose(transforms)
22
+ transformed = augs(image=image)
23
 
24
+ return transformed['image']
25
 
26
  class CustomEfficientNet(nn.Module):
27
  """
 
58
  x = self.model(x)
59
 
60
  return x
61
+
62
+ class CustomViT(nn.Module):
63
+ """
64
+ This class defines a custom ViT network.
65
+
66
+ Parameters
67
+ ----------
68
+ target_size : int
69
+ Number of units for the output layer.
70
+ pretrained : bool
71
+ Determine if pretrained weights are used.
72
+
73
+ Attributes
74
+ ----------
75
+ model : nn.Module
76
+ CustomViT model.
77
+ """
78
+ def __init__(self, model_name : str = 'vit_base_patch16_224',
79
+ target_size : int = 4, pretrained : bool = True):
80
+ super().__init__()
81
+ self.model = timm.create_model(model_name,
82
+ pretrained=pretrained,
83
+ num_classes=target_size)
84
+
85
+ in_features = self.model.head.in_features
86
+ self.model.head = nn.Sequential(
87
+ #nn.Dropout(0.5),
88
+ nn.Linear(in_features, 256),
89
+ nn.ReLU(),
90
+ nn.Dropout(0.5),
91
+ nn.Linear(256, target_size)
92
+ )
93
+
94
+ def forward(self, x : torch.Tensor) -> torch.Tensor:
95
+ x = self.model(x)
96
+
97
+ return x
model.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  import numpy as np
7
  import torchvision
8
 
 
 
9
  import icevision.models.ross.efficientdet
10
 
11
  MODEL_TYPE = icevision.models.ross.efficientdet
@@ -81,10 +83,14 @@ def predict_class(model, image, bboxes):
81
  img = image.copy()
82
  bbox = np.array(bbox).astype(int)
83
  cropped_img = PIL.Image.fromarray(img).crop(bbox)
84
- cropped_img = np.array(cropped_img).transpose(2, 0, 1)
85
- cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
86
-
87
- y_preds = model(cropped_img)
 
 
 
 
88
  preds.append(y_preds.softmax(1).detach().numpy())
89
 
90
  preds = np.concatenate(preds).argmax(1)
 
6
  import numpy as np
7
  import torchvision
8
 
9
+ from classifier import transform_image
10
+
11
  import icevision.models.ross.efficientdet
12
 
13
  MODEL_TYPE = icevision.models.ross.efficientdet
 
83
  img = image.copy()
84
  bbox = np.array(bbox).astype(int)
85
  cropped_img = PIL.Image.fromarray(img).crop(bbox)
86
+ cropped_img = np.array(cropped_img)
87
+ #cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
88
+
89
+ tran_image = transform_image(cropped_img, 224)
90
+ tran_image = tran_image.transpose(2, 0, 1)
91
+ tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
92
+ print(tran_image.shape)
93
+ y_preds = model(tran_image)
94
  preds.append(y_preds.softmax(1).detach().numpy())
95
 
96
  preds = np.concatenate(preds).argmax(1)