hsshin98 commited on
Commit
e20de5f
1 Parent(s): aff8d56
Files changed (2) hide show
  1. app.py +3 -4
  2. demo/predictor.py +12 -1
app.py CHANGED
@@ -82,10 +82,7 @@ def save_masks(preds, text):
82
  cv2.imwrite(dir, mask * 255)
83
 
84
  def predict(image, text):
85
- args = get_parser().parse_args()
86
- cfg = setup_cfg(args)
87
- demo = VisualizationDemo(cfg, text=text)
88
- predictions, visualized_output = demo.run_on_image(image)
89
  #save_masks(predictions, text.split(','))
90
  canvas = fc(visualized_output.fig)
91
  canvas.draw()
@@ -96,6 +93,8 @@ def predict(image, text):
96
  if __name__ == "__main__":
97
  args = get_parser().parse_args()
98
  cfg = setup_cfg(args)
 
 
99
 
100
  iface = gr.Interface(
101
  fn=predict,
 
82
  cv2.imwrite(dir, mask * 255)
83
 
84
  def predict(image, text):
85
+ predictions, visualized_output = demo.run_on_image(image, text)
 
 
 
86
  #save_masks(predictions, text.split(','))
87
  canvas = fc(visualized_output.fig)
88
  canvas.draw()
 
93
  if __name__ == "__main__":
94
  args = get_parser().parse_args()
95
  cfg = setup_cfg(args)
96
+ global demo
97
+ demo = VisualizationDemo(cfg)
98
 
99
  iface = gr.Interface(
100
  fn=predict,
demo/predictor.py CHANGED
@@ -49,7 +49,7 @@ class VisualizationDemo(object):
49
  self.metadata = ns()
50
  self.metadata.stuff_classes = pred.test_class_texts
51
 
52
- def run_on_image(self, image):
53
  """
54
  Args:
55
  image (np.ndarray): an image of shape (H, W, C) (in BGR order).
@@ -59,6 +59,17 @@ class VisualizationDemo(object):
59
  vis_output (VisImage): the visualized image output.
60
  """
61
  vis_output = None
 
 
 
 
 
 
 
 
 
 
 
62
  predictions = self.predictor(image)
63
  # Convert image from OpenCV BGR format to Matplotlib RGB format.
64
  image = image[:, :, ::-1]
 
49
  self.metadata = ns()
50
  self.metadata.stuff_classes = pred.test_class_texts
51
 
52
+ def run_on_image(self, image, text=None):
53
  """
54
  Args:
55
  image (np.ndarray): an image of shape (H, W, C) (in BGR order).
 
59
  vis_output (VisImage): the visualized image output.
60
  """
61
  vis_output = None
62
+
63
+ if text is not None:
64
+ pred = self.predictor.model.sem_seg_head.predictor
65
+ pred.test_class_texts = text.split(',')
66
+ pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
67
+ #imagenet_templates.IMAGENET_TEMPLATES,
68
+ ['A photo of a {} in the scene',],
69
+ pred.clip_model).permute(1, 0, 2).float().repeat(1, 80, 1)
70
+ self.metadata = ns()
71
+ self.metadata.stuff_classes = pred.test_class_texts
72
+
73
  predictions = self.predictor(image)
74
  # Convert image from OpenCV BGR format to Matplotlib RGB format.
75
  image = image[:, :, ::-1]