import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_rgb_mask(mask):
    """Turn a binary (0/255) mask into a 3-channel mask with a random color."""
    color = tuple(np.random.choice(range(256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)


def detect_using_clip(image, prompts, threshold=0.4):
    """Run CLIPSeg once per prompt and return one colored binary mask per prompt."""
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)

    logits = outputs.logits
    if logits.dim() == 2:  # single prompt: the model squeezes out the batch dimension
        logits = logits.unsqueeze(0)
    preds = logits.unsqueeze(1)

    predicted_masks = []
    for i, prompt in enumerate(prompts):
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        predicted_masks.append(create_rgb_mask(predicted_image))
    return predicted_masks


def visualize_images(image, predicted_images, brightness=15):
    """Overlay every predicted mask on the resized input image and brighten the result."""
    alpha = 0.7
    resize_image_copy = cv2.resize(image, (352, 352))  # CLIPSeg masks are 352x352
    for mask_image in predicted_images:
        resize_image_copy = cv2.addWeighted(resize_image_copy, alpha, mask_image, 1 - alpha, 10)
    return cv2.convertScaleAbs(resize_image_copy, alpha=1.8, beta=brightness)


def shot(brightness, image, labels_text):
    """Gradio callback: parse the comma-separated prompts, segment, and visualize."""
    prompts = [label.strip() for label in labels_text.split(",") if label.strip()]
    predicted_images = detect_using_clip(image, prompts=prompts)
    return visualize_images(image=image, predicted_images=predicted_images, brightness=brightness)


iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(5, 50, value=15, label="Brightness", info="Choose between 5 and 50"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and a comma-separated list of categories to be detected (at least 2).",
    title="Zero-shot Image Segmentation with Prompts",
    examples=[
        [15, "images/room.jpg", "bed, table, plant, light, window"],
        [10, "images/image2.png", "banner, building, door, sign"],
        [19, "images/seats.jpg", "door, table, chairs"],
        [20, "images/vegetables.jpg", "carrot, radish, beans, potato, brinjal, basket"],
        [17, "images/room2.jpg", "door, plants, dog, coffee table, mug, pillow, table lamp, carpet, pictures, clock"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
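
# --- Optional: quick local check without the Gradio UI ----------------------
# A minimal sketch (not part of the app) for exercising detect_using_clip()
# and visualize_images() directly; it assumes an RGB image exists at
# "images/room.jpg" (the path used in the examples above) and that the output
# filename "overlay.png" is acceptable. Uncomment and run in place of
# iface.launch() when debugging the segmentation pipeline.
#
# if __name__ == "__main__":
#     bgr = cv2.imread("images/room.jpg")                    # OpenCV loads BGR
#     rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)             # model expects RGB
#     masks = detect_using_clip(rgb, prompts=["bed", "window", "plant"])
#     overlay = visualize_images(image=rgb, predicted_images=masks, brightness=15)
#     cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))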