Samarth991's picture
Update app.py
458bcca
raw
history blame
No virus
3.53 kB
from turtle import title
import gradio as gr
from transformers import pipeline
import numpy as np
from PIL import Image
import torch
import cv2
from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation,AutoProcessor,AutoConfig
# Hugging Face Hub id shared by the processor and the segmentation model so
# the two can never drift apart.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Loaded once at import time; both objects are reused by every request.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

# NOTE(review): `classes` is not read anywhere in the visible part of the file.
classes = []
def create_rgb_mask(mask):
    """Paint the foreground of a single-channel mask with one random colour.

    Args:
        mask: 2-D array-like in which foreground pixels hold the value 255.
            Any numeric dtype is accepted.

    Returns:
        A new ``uint8`` array of shape ``mask.shape + (3,)``. Pixels where
        ``mask == 255`` carry the same randomly drawn 3-channel colour; every
        other pixel repeats its original mask value in all three channels.
        The input is not modified.
    """
    mask = np.asarray(mask)
    # One colour per call, each channel drawn uniformly from 0..255.
    color = np.random.randint(0, 256, size=3)
    # Stack with NumPy rather than cv2.merge: OpenCV rejects dtypes it has no
    # matrix type for (e.g. the int64 arrays produced by np.where(..., 255, 0)),
    # while dstack builds the same H x W x 3 layout for any dtype.
    rgb = np.dstack((mask, mask, mask))
    rgb[mask == 255] = color
    return rgb.astype(np.uint8)
def detect_using_clip(image,prompts=None,threshould=0.4):
    """Produce one boolean segmentation mask per text prompt with CLIPSeg.

    Uses the module-level ``processor`` and ``model``.

    Args:
        image: Image handed to ``processor``; it is repeated once per prompt
            so every prompt is evaluated against the same picture.
        prompts: Sequence of text prompts. ``None`` or an empty sequence
            returns an empty list without running the model.
        threshould: Cut-off applied to the sigmoid of the model logits; pixels
            strictly above it are part of the mask. (The spelling is kept
            because it is part of the public signature.)

    Returns:
        A list of boolean NumPy arrays, one per prompt, in prompt order.
    """
    if not prompts:
        # Nothing to segment; also avoids calling the processor with no text.
        return []
    prompts = list(prompts)

    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    logits = outputs.logits
    if logits.dim() == 2:
        # NOTE(review): for a single prompt the logits may come back without a
        # batch axis (appears to depend on the transformers release - confirm).
        # Restore it so indexing by prompt is always valid.
        logits = logits.unsqueeze(0)

    probabilities = torch.sigmoid(logits).detach().cpu().numpy()
    # Thresholding directly yields the boolean mask; no 0/255 round-trip needed.
    return [probabilities[index] > threshould for index in range(len(prompts))]
def visualize_images(image,predicted_images,brightness=15,contrast=1.8):
    """Return a 352x352 copy of ``image`` with contrast and brightness adjusted.

    Args:
        image: Image array accepted by ``cv2.resize``.
        predicted_images: Accepted for interface compatibility; the current
            implementation does not read it.
        brightness: Offset passed to ``cv2.convertScaleAbs`` as ``beta``.
        contrast: Gain passed to ``cv2.convertScaleAbs`` as ``alpha``.

    Returns:
        The resized, rescaled image produced by ``cv2.convertScaleAbs``.
    """
    target_size = (352, 352)
    resized = cv2.resize(image, target_size)
    return cv2.convertScaleAbs(resized, alpha=contrast, beta=brightness)
def shot(alpha,beta,image,labels_text):
    """Gradio callback: segment ``image`` for each comma-separated label.

    Args:
        alpha: Forwarded unchanged to ``overlay_masks`` as ``alpha``.
        beta: Forwarded unchanged to ``overlay_masks`` as ``beta``.
        image: Input image array; resized to 352x352 for the overlay.
        labels_text: One label, or several labels separated by commas.

    Returns:
        Whatever ``overlay_masks`` returns for the resized image and the
        stacked per-label masks.
    """
    raw_labels = labels_text.split(',') if "," in labels_text else [labels_text]
    prompts = [label.strip() for label in raw_labels]

    # Suffix each label with its position so repeated labels stay distinct.
    mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(prompts)]
    # One tab20 colour per label; the last (alpha) channel is dropped.
    palette = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]

    canvas = cv2.resize(image, (352, 352))
    masks = detect_using_clip(image, prompts=prompts)
    stacked_masks = np.stack(masks, -1)  # masks along the last axis

    return overlay_masks(
        canvas,
        stacked_masks,
        labels=mask_labels,
        colors=palette,
        alpha=alpha,
        beta=beta,
    )
# Gradio UI: two sliders, an image and a text box are passed positionally to
# `shot` in the order they are listed under `inputs`.
#
# NOTE(review): the first slider is labelled "beta" but feeds `shot`'s first
# parameter (`alpha`), and the second is labelled "alpha" but feeds `beta`.
# Confirm the intended mapping before renaming or reordering.
# NOTE(review): the example rows use 19/20/17 and 1.5/1.8/2 for the sliders,
# which lie outside the slider ranges declared below - confirm intended values.
iface = gr.Interface(
    fn=shot,
    inputs=[
        # Help texts state the range the slider actually enforces.
        gr.Slider(1, 5, value=2, label="beta", info="Choose between 1 and 5"),
        gr.Slider(0.1, 1, value=1, label="alpha", info="Choose between 0.1 and 1"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an Image and lists of category to be detected separated by commas(atleast 2 )",
    title="Zero-shot Image Segmentation with Prompt ",
    examples=[
        [19,1.5,"images/seats.jpg","door,table,chairs"],
        [20,1.8,"images/vegetables.jpg","carrot,white radish,brinjal,basket,potato"],
        [17,2,"images/room2.jpg","door, plants, dog, coffe table, table lamp, carpet, door"]
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()