"""Gradio demo: zero-shot image segmentation from text prompts using CLIPSeg."""
import gradio as gr
import numpy as np
import torch
from torch import nn
import cv2

from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the pretrained CLIPSeg processor and segmentation model
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def create_rgb_mask(mask):
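    """(Unused helper) Convert a binary {0, 255} mask into a 3-channel mask painted with a random color."""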
    color = tuple(np.random.choice(range(0,256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask==255] = color
    return gray_3_channel.astype(np.uint8)


def detect_using_clip(image, prompts=[], threshold=0.1):
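    """Run CLIPSeg on the image with the given text prompts.

    Returns one binary {0, 255} mask per prompt. Each pixel is assigned to the
    prompt with the highest sigmoid score, provided that score exceeds
    `threshold`; otherwise it is left unlabeled.
    """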
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # Use 'torch.no_grad()' to disable gradient computation
        outputs = model(**inputs)
    # Upsample the per-prompt logits to the input image resolution
    preds = nn.functional.interpolate(
        outputs.logits.unsqueeze(1),
        size=(image.shape[0], image.shape[1]),
        mode="bilinear",
    )

    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))

    # Prepend a dummy "unlabeled" row filled with the threshold value
    flat_preds_with_threshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)
    flat_preds_with_threshold[1:preds.shape[0] + 1, :] = flat_preds

    # For each pixel, take the index of the highest-scoring row (0 means unlabeled)
    inds = torch.topk(flat_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    predicted_masks = []

    # Row 0 is the unlabeled class, so prompt i corresponds to index i
    for i in range(1, len(prompts) + 1):
        mask = np.where(inds == i, 255, 0)
        predicted_masks.append(mask)
    
    return predicted_masks

def visualize_images(image,predicted_images,brightness=15,contrast=1.8):
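    """(Unused helper) Resize the image to 352x352 and adjust its brightness/contrast."""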
    alpha = 0.7
    image_resize = cv2.resize(image,(352,352))
    resize_image_copy = image_resize.copy()

    # for mask_image in predicted_images:
    #     resize_image_copy = cv2.addWeighted(resize_image_copy,alpha,mask_image,1-alpha,10)

    return cv2.convertScaleAbs(resize_image_copy, alpha=contrast, beta=brightness)     

def shot(alpha,beta,image,labels_text):
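    """Gradio callback: parse comma-separated labels, run CLIPSeg, and overlay the predicted masks."""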
    print(labels_text)
    
    if "," in labels_text:
        prompts = labels_text.split(',')
    else:
        prompts = [labels_text]
    print(prompts)
    
    prompts = list(map(lambda x: x.strip(),prompts))

    mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(prompts)]
    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]


    predicted_masks = detect_using_clip(image, prompts=prompts)
    bool_masks = [predicted_mask.astype("bool") for predicted_mask in predicted_masks]
    category_image = overlay_masks(image, np.stack(bool_masks, -1), labels=mask_labels, colors=cmap, alpha=alpha, beta=beta)

    return category_image

iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(0.1, 1, value=0.3, step=0.1, label="alpha", info="Choose between 0.1 and 1"),
        gr.Slider(0.1, 1, value=0.7, step=0.1, label="beta", info="Choose between 0.1 and 1"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and the labels to be detected, separated by commas (at least 2).",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        [0.4, 1, "images/room.jpg", "chair, plant, flower pot, white cabinet, paintings, decorative plates, books"],
        [0.4, 1, "images/seats.jpg", "door, table, chairs"],
        [0.3, 0.8, "images/vegetables.jpg", "carrot, white radish, brinjal, basket, potato"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()