# Source: CAT-Seg / app.py (hsshin98, commit 904905b "requirements", 3.01 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on
import tempfile
import time
import warnings
import cv2
import numpy as np
import tqdm
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger
from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
# Window title inherited from the upstream detectron2/MaskFormer demo this file
# was adapted from; not referenced elsewhere in this file as shown.
WINDOW_NAME: str = "MaskFormer demo"
def setup_cfg(args):
    """Assemble the frozen detectron2 config for CAT-Seg from parsed CLI options.

    Starts from ``get_cfg()`` defaults, applies the DeepLab and CAT-Seg config
    extensions, merges ``args.config_file`` and then the ``args.opts``
    KEY/VALUE list on top, and freezes the result.

    Args:
        args: namespace produced by ``get_parser().parse_args()``; only
            ``config_file`` and ``opts`` are read.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    # The add_*_config helpers run before any merge (presumably so the YAML
    # file and opts may set the keys they declare).
    for extend_config in (add_deeplab_config, add_cat_seg_config):
        extend_config(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def get_parser():
    """Return the argument parser for the demo's command-line options.

    Options:
        --config-file  path to the YAML config (default
                       ``configs/vitl_swinb_384.yaml``).
        --input        one or more image paths or a glob pattern; not read
                       anywhere in this file as shown.
        --opts         config overrides as ``KEY VALUE`` pairs. Everything
                       after the flag is captured verbatim
                       (``argparse.REMAINDER``). Defaults to the released
                       weights, VOC-20 class lists and sliding-window settings.
    """
    default_overrides = [
        "MODEL.WEIGHTS", "model_final.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
    ]
    input_help = (
        "A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'"
    )

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    cli.add_argument("--input", nargs="+", help=input_help)
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=default_overrides,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli
def save_masks(preds, text, out_dir="masks"):
    """Write one binary PNG mask per class name into ``out_dir``.

    Each pixel is assigned the channel with the highest score
    (``argmax`` over dim 0 of ``preds["sem_seg"]``). The mask for class ``i``
    is 255 where that winner is ``i`` and 0 elsewhere.

    Args:
        preds: prediction dict whose ``"sem_seg"`` entry is a tensor reduced
            over dim 0. Presumably (C, H, W) per-class scores — TODO confirm.
        text: sequence of class names. Position ``i`` is assumed to match
            channel ``i`` of ``preds["sem_seg"]`` (NOTE(review): verify
            against how VisualizationDemo parses the same text).
        out_dir: directory the masks are written to; created if missing.
            Defaults to ``"masks"`` as before.

    Output files are named ``mask_<name>.png``. Names are whitespace-stripped
    and path-unsafe characters become ``_``. A failed write emits a warning
    instead of being silently ignored.
    """
    # Per-pixel winning class index; (H, W) after reducing dim 0.
    labels = preds["sem_seg"].argmax(dim=0).cpu().numpy()

    # cv2.imwrite returns False instead of raising when the target
    # directory does not exist, which silently lost every mask before.
    os.makedirs(out_dir, exist_ok=True)

    # Class names come straight from the user's textbox. Neutralise path
    # separators and characters invalid in filenames so a name like "../x"
    # cannot write outside out_dir.
    unsafe_chars = str.maketrans({c: "_" for c in '\\/:*?"<>|\0'})

    for idx, name in enumerate(text):
        safe_name = name.strip().translate(unsafe_chars)
        path = os.path.join(out_dir, f"mask_{safe_name}.png")
        # Explicit uint8 0/255 image; PNG only supports 8/16-bit depths.
        mask = (labels == idx).astype(np.uint8) * 255
        if not cv2.imwrite(path, mask):
            warnings.warn(f"could not write mask for class {name!r} to {path}")
def predict(image, text):
    """Run CAT-Seg on one image and return the rendered visualization.

    This is the Gradio ``fn`` callback.

    Args:
        image: array from the ``gr.Image()`` input; ``None`` when the user
            submits without uploading an image.
        text: raw textbox contents. Passed unchanged to ``VisualizationDemo``
            and split on ``","`` to name the per-class masks saved by
            ``save_masks``.

    Returns:
        (H, W, 3) uint8 array of the rendered figure, with the channel order
        reversed (see the note at the return statement).

    Raises:
        gr.Error: if no image was supplied, so Gradio shows a readable message
            instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # NOTE(review): the CLI is re-parsed and a fresh VisualizationDemo is built
    # on every request, because the demo takes the class text at construction.
    # This presumably reloads model weights each time.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg, text=text)
    predictions, visualized_output = demo.run_on_image(image)
    save_masks(predictions, text.split(','))

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which Matplotlib deprecated in 3.8
    # and removed in 3.10. Its (H, W, 4) shape comes from the renderer itself,
    # so no manual reshape via get_width_height() is needed (that could also
    # mismatch on HiDPI canvases).
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop alpha and reverse the channel order, as the original did.
    # NOTE(review): presumably this undoes an RGB/BGR swap inside
    # VisualizationDemo — confirm. .copy() detaches the result from the Agg
    # buffer so a later redraw cannot change it.
    return rgba[..., :3][..., ::-1].copy()
if __name__ == "__main__":
    # Build the config once at startup. This cfg is not passed on —
    # predict() re-parses the CLI itself — so building it here presumably
    # serves to surface config errors before the UI launches.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # Image upload plus a free-text list of classes to segment.
    ui_inputs = [
        gr.Image(),
        gr.Textbox(placeholder="Classes to segment"),
    ]
    iface = gr.Interface(fn=predict, inputs=ui_inputs, outputs="image")
    iface.launch()