File size: 3,006 Bytes
d617811
 
 
 
 
 
904905b
d617811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
# Install detectron2 from GitHub when this module is imported, i.e. before the
# detectron2 imports further down are executed. The exit status of the command
# is not checked here, so a failed install only surfaces at those imports.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc

# constants
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a leftover window title from the upstream demo script — confirm
# before removing.
WINDOW_NAME = "MaskFormer demo"


def setup_cfg(args):
    """Build the frozen config used by the demo.

    Args:
        args: parsed command-line namespace; ``args.config_file`` and
            ``args.opts`` are read.

    Returns:
        The config object returned by ``get_cfg()`` after the config file
        and the ``KEY VALUE`` overrides were merged in and ``freeze()`` was
        called on it.
    """
    config = get_cfg()
    # Run the DeepLab and CAT-Seg config hooks on the base config first, in
    # this order, before the file and command-line values are merged.
    for extend in (add_deeplab_config, add_cat_seg_config):
        extend(config)
    # File values first, then command-line overrides on top of them.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config


def get_parser():
    """Create the command-line parser for the demo.

    Returns:
        argparse.ArgumentParser: parser with three options:
        ``--config-file`` (path, has a default), ``--input`` (one or more
        values) and ``--opts`` (collects every remaining argument; falls
        back to a built-in list of config overrides when absent).
    """
    # Config overrides used when --opts is not passed on the command line.
    fallback_opts = [
        "MODEL.WEIGHTS", "model_final.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
    ]
    input_help = (
        "A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'"
    )

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    cli.add_argument("--input", nargs="+", help=input_help)
    # REMAINDER: everything after --opts is taken verbatim as KEY VALUE pairs.
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=fallback_opts,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli

def save_masks(preds, text):
    """Write one binary mask PNG per class name into ``masks/``.

    Args:
        preds: mapping with a ``'sem_seg'`` entry. ``argmax(dim=0)`` is taken
            over that entry and the result is moved to a NumPy array, giving
            one label index per pixel.
        text: iterable of class names. The i-th name is paired with label
            index ``i`` and used in the file name ``masks/mask_<name>.png``.

    Each file holds 255 where the per-pixel label equals the class index and
    0 elsewhere.
    """
    label_map = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # C H W -> H W

    # cv2.imwrite does not create missing directories; without this the
    # writes fail without raising on a fresh checkout.
    os.makedirs("masks", exist_ok=True)

    for index, name in enumerate(text):
        out_path = f"masks/mask_{name}.png"
        # Explicit uint8 so the PNG is written as 8-bit without relying on an
        # implicit conversion of the int64 result of ``bool * 255``.
        mask = (label_map == index).astype(np.uint8) * 255
        cv2.imwrite(out_path, mask)

def predict(image, text):
    """Segment one image for the given classes and return the rendering.

    Args:
        image: input image; forwarded unchanged to ``demo.run_on_image``.
        text: comma-separated class names. The whole string is handed to
            ``VisualizationDemo``; it is split on ``','`` to name the mask
            files written by ``save_masks``.

    Returns:
        numpy.ndarray: the rendered figure as an ``(H, W, 3)`` uint8 array
        whose channel order is reversed relative to the canvas RGB buffer.
    """
    # NOTE: the command line is re-parsed and the config and demo object are
    # rebuilt on every call; the demo is constructed with the per-request
    # ``text``.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg, text=text)
    predictions, visualized_output = demo.run_on_image(image)
    save_masks(predictions, text.split(','))

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which was deprecated in
    # Matplotlib 3.8 and removed in 3.10. It exposes the same pixels as an
    # (H, W, 4) RGBA buffer, so the alpha channel is dropped here.
    rgba = np.asarray(canvas.buffer_rgba())
    rgb = rgba[..., :3]

    # Reverse the channel order (as the original did) and copy, so the result
    # does not alias the canvas' internal buffer.
    return np.ascontiguousarray(rgb[..., ::-1])

if __name__ == "__main__":
    # Parse the command line and build the config once at startup. Neither
    # value is used further down in this block; predict() builds its own.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # Two inputs (an image and a free-text class list), one image output.
    image_input = gr.Image()
    classes_input = gr.Textbox(placeholder="Classes to segment")
    iface = gr.Interface(
        fn=predict,
        inputs=[image_input, classes_input],
        outputs="image",
    )
    iface.launch()