# gun_detection / detection.py
# Uploaded by siddhantgore in commit ac8f703 ("Upload 3 files"); 1.61 kB; scan result: No virus.
from detectron2.engine import DefaultPredictor
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2 import model_zoo
from detectron2.data.datasets import register_coco_instances
from PIL import Image
import PIL
import cv2
import numpy as np
import matplotlib.pyplot as plt
class Detector:
    """Faster R-CNN gun detector built on a detectron2 model-zoo config.

    Loads the ``COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml`` base config,
    points it at locally trained weights, and exposes :meth:`onImage` to run
    inference on an image file and save an annotated copy.
    """

    # Name under which the label names are registered in MetadataCatalog.
    DATASET_NAME = "guns"
    # Label names, indexed by the model's predicted class id; their count
    # also sets ROI_HEADS.NUM_CLASSES so the two cannot drift apart.
    CLASSES = ("guns", "Gun")

    def __init__(
        self,
        model_type="object_detection",
        *,
        weights="model_final.pth",
        score_thresh=0.8,
        device="cpu",
    ):
        """Build the detectron2 config and the predictor.

        Args:
            model_type: Not used by this class; kept so existing callers that
                pass it keep working.
            weights: Path to the trained weights file (``cfg.MODEL.WEIGHTS``).
            score_thresh: Test-time score threshold for keeping a detection
                (``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``).
            device: Device string for ``cfg.MODEL.DEVICE`` (e.g. ``"cpu"``).
        """
        self.cfg = get_cfg()
        # Start from the model-zoo Faster R-CNN R50-FPN 1x configuration.
        self.cfg.merge_from_file(
            model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")
        )
        self.cfg.MODEL.WEIGHTS = weights
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
        self.cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(self.CLASSES)
        self.cfg.MODEL.DEVICE = device
        # Register label names so the Visualizer can print them on boxes.
        # A fresh list is passed so the shared class attribute is never mutated.
        MetadataCatalog.get(self.DATASET_NAME).set(thing_classes=list(self.CLASSES))
        self.predictor = DefaultPredictor(self.cfg)

    def onImage(self, imagePath, output_path="result.jpg"):
        """Run detection on an image file and save an annotated copy.

        Args:
            imagePath: Path of the image to read with ``cv2.imread``.
            output_path: Where to write the annotated image. Defaults to
                ``"result.jpg"`` (relative to the working directory), as before.

        Returns:
            The path the annotated image was written to.

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot read ``imagePath``.
            OSError: If ``cv2.imwrite`` reports that writing failed.
        """
        image = cv2.imread(imagePath)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, which would otherwise fail inside the model.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {imagePath!r}")
        # DefaultPredictor takes BGR input by default, matching cv2.imread.
        predictions = self.predictor(image)
        # Visualizer requires an RGB image; flip OpenCV's BGR channel order,
        # otherwise the saved result has its red and blue channels swapped.
        viz = Visualizer(
            image[:, :, ::-1],
            MetadataCatalog.get(self.DATASET_NAME),
            scale=1,
        )
        output = viz.draw_instance_predictions(predictions["instances"].to("cpu"))
        # get_image() returns RGB; convert back to BGR for cv2.imwrite.
        if not cv2.imwrite(output_path, output.get_image()[:, :, ::-1]):
            raise OSError(f"Could not write image: {output_path!r}")
        return output_path