from turtle import title  # NOTE(review): unused; `turtle` imports tkinter, which may be missing on headless hosts — consider removing
import gradio as gr
from transformers import pipeline
import numpy as np
from PIL import Image
import torch
import cv2
from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation, AutoProcessor, AutoConfig

# Processor and model must come from the same CLIPSeg checkpoint.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

classes = list()  # not used elsewhere in this file


def create_rgb_mask(mask):
    """Colour the foreground of a single-channel mask with one random RGB colour.

    Args:
        mask: single-channel array (assumed 2-D); pixels equal to 255 are
            treated as foreground.

    Returns:
        A 3-channel uint8 array in which foreground pixels take the random
        colour and every other pixel keeps its mask value on all 3 channels.
    """
    color = tuple(np.random.choice(range(0, 256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)


def detect_using_clip(image, prompts=None, threshould=0.4):
    """Segment ``image`` once per text prompt with CLIPSeg.

    Args:
        image: image accepted by ``CLIPSegProcessor`` (the Gradio callback
            passes a NumPy array).
        prompts: sequence of text prompts; one mask is produced per prompt.
            ``None`` or empty yields an empty list.
        threshould: cut-off applied to the sigmoid probabilities; pixels
            strictly above it are foreground. (Spelling kept for backward
            compatibility with keyword callers.)

    Returns:
        List of boolean masks, one per prompt, at the model's output
        resolution.
    """
    if not prompts:
        return []
    prompts = list(prompts)

    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        logits = model(**inputs).logits
        # Depending on the transformers version, CLIPSeg may squeeze away the
        # batch axis for a single prompt ((H, W) instead of (1, H, W)).
        # Normalise to (N, H, W) so that index i always selects prompt i's mask.
        logits = logits.reshape(len(prompts), *logits.shape[-2:])
        probabilities = torch.sigmoid(logits).cpu().numpy()

    return [probability > threshould for probability in probabilities]


def visualize_images(image, predicted_images, brightness=15, contrast=1.8):
    """Return ``image`` resized to 352x352 with contrast/brightness applied.

    ``predicted_images`` is currently ignored (mask blending was disabled);
    the parameter is kept so existing callers keep working. Each pixel
    becomes ``|contrast * p + brightness|`` saturated to uint8
    (``cv2.convertScaleAbs``).
    """
    image_resize = cv2.resize(image, (352, 352))
    return cv2.convertScaleAbs(image_resize, alpha=contrast, beta=brightness)


def shot(alpha, beta, image, labels_text):
    """Gradio callback: segment every comma-separated label and overlay the masks.

    Args:
        alpha: value of the first slider, forwarded unchanged to
            ``overlay_masks(alpha=...)``.
        beta: value of the second slider, forwarded unchanged to
            ``overlay_masks(beta=...)``.
        image: input image from the Gradio ``"image"`` component (assumed an
            RGB NumPy array — confirm against the Gradio version in use).
        labels_text: comma-separated category names, e.g. ``"door, table"``.

    Returns:
        Whatever ``overlay_masks`` returns for the resized image and masks.

    Raises:
        ValueError: if no image is supplied or no non-empty label remains.
    """
    if image is None:
        raise ValueError("Please provide an image.")

    # split(",") already yields [text] when there is no comma; trim each entry
    # and drop empties (e.g. produced by a trailing or doubled comma).
    prompts = [label.strip() for label in (labels_text or "").split(",") if label.strip()]
    if not prompts:
        raise ValueError("Please enter at least one category (comma separated).")

    mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(prompts)]
    # One tab20 colour per label; [..., :-1] drops the alpha channel (RGBA -> RGB).
    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]

    predicted_masks = detect_using_clip(image, prompts=prompts)

    # Resize the image to the masks' own spatial size so they align pixel-for-pixel.
    mask_height, mask_width = predicted_masks[0].shape
    resize_image = cv2.resize(image, (mask_width, mask_height))

    return overlay_masks(
        resize_image,
        np.stack(predicted_masks, -1),
        labels=mask_labels,
        colors=cmap,
        alpha=alpha,
        beta=beta,
    )


iface = gr.Interface(
    fn=shot,
    # Components map positionally onto shot(alpha, beta, image, labels_text).
    inputs=[
        gr.Slider(1, 5, value=2, label="alpha", info="Choose between 1 and 5"),
        gr.Slider(0.1, 1, value=1, label="beta", info="Choose between 0.1 and 1"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and a comma-separated list of categories to detect",
    title="Zero-shot Image Segmentation with Prompt ",
    # NOTE(review): the first two values of each example lie outside the slider
    # ranges (1-5 and 0.1-1); they look like leftover brightness/contrast
    # settings — confirm and update.
    examples=[
        [19, 1.5, "images/seats.jpg", "door,table,chairs"],
        [20, 1.8, "images/vegetables.jpg", "carrot,white radish,brinjal,basket,potato"],
        [17, 2, "images/room2.jpg", "door, plants, dog, coffe table, table lamp, carpet, door"],
    ],
)
iface.launch()