kushagra124 commited on
Commit
d1bffba
1 Parent(s): 297686d

adding app with CLIP image segmentation

Browse files
Files changed (4) hide show
  1. app.py +93 -0
  2. images/image2.png +0 -0
  3. images/room.jpg +0 -0
  4. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from turtle import title
2
+ import os
3
+ import gradio as gr
4
+ from transformers import pipeline
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ import cv2
9
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation,AutoProcessor,AutoConfig
10
+ from skimage.measure import label, regionprops
11
+
12
+ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
13
+ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
14
+ classes = list()
15
+
16
+ def create_mask(image,image_mask,alpha=0.7):
17
+ mask = np.zeros_like(image)
18
+ # copy your image_mask to all dimensions (i.e. colors) of your image
19
+ for i in range(3):
20
+ mask[:,:,i] = image_mask.copy()
21
+ # apply the mask to your image
22
+ overlay_image = cv2.addWeighted(mask,alpha,image,1-alpha,0)
23
+ return overlay_image
24
+
25
+ def rescale_bbox(bbox,orig_image_shape=(1024,1024),model_shape=352):
26
+ bbox = np.asarray(bbox)/model_shape
27
+ y1,y2 = bbox[::2] *orig_image_shape[0]
28
+ x1,x2 = bbox[1::2]*orig_image_shape[1]
29
+ return [int(y1),int(x1),int(y2),int(x2)]
30
+
31
+ def detect_using_clip(image,prompts=[],threshould=0.4):
32
+ model_detections = dict()
33
+ predicted_images = dict()
34
+ inputs = processor(
35
+ text=prompts,
36
+ images=[image] * len(prompts),
37
+ padding="max_length",
38
+ return_tensors="pt",
39
+ )
40
+ with torch.no_grad(): # Use 'torch.no_grad()' to disable gradient computation
41
+ outputs = model(**inputs)
42
+ preds = outputs.logits.unsqueeze(1)
43
+
44
+ detection = outputs.logits[0] # Assuming class index 0
45
+ for i,prompt in enumerate(prompts):
46
+ predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
47
+ predicted_image = np.where(predicted_image>threshould,255,0)
48
+ # extract countours from the image
49
+ lbl_0 = label(predicted_image)
50
+ props = regionprops(lbl_0)
51
+ prompt = prompt.lower()
52
+ model_detections[prompt] = [rescale_bbox(prop.bbox,orig_image_shape=image.shape[:2],model_shape=predicted_image.shape[0]) for prop in props]
53
+ predicted_images[prompt]= cv2.resize(predicted_image,image.shape[:2])
54
+ return model_detections , predicted_images
55
+
56
+ def visualize_images(image,detections,predicted_image,prompt):
57
+ alpha = 0.7
58
+ H,W = image.shape[:2]
59
+ prompt = prompt.lower()
60
+ image_copy = image.copy()
61
+ mask_image = create_mask(image=image_copy,image_mask=predicted_image)
62
+
63
+ if prompt not in detections.keys():
64
+ print("prompt not in query ..")
65
+ return image_copy
66
+ for bbox in detections[prompt]:
67
+ cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
68
+ cv2.putText(image_copy,str(prompt),(int(bbox[1]), int(bbox[0])),cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
69
+ final_image = cv2.addWeighted(image_copy,alpha,mask_image,1-alpha,0)
70
+ return final_image
71
+
72
+ def shot(image, labels_text,selected_categoty):
73
+ prompts = labels_text.split(',')
74
+ prompts = list(map(lambda x: x.strip(),prompts))
75
+
76
+ model_detections,predicted_images = detect_using_clip(image,prompts=prompts)
77
+
78
+ category_image = visualize_images(image=image,detections=model_detections,predicted_image=predicted_images,prompt=selected_categoty)
79
+ return category_image
80
+
81
+ iface = gr.Interface(fn=shot,
82
+ inputs = ["image","text","text"],
83
+ outputs = "image",
84
+ description ="Add an Image and list of category to be detected separated by commas",
85
+ title = "Zero-shot Image Classification with Prompt ",
86
+ examples=[
87
+ ["images/room.jpg","bed, table, plant, light, window",'plant'],
88
+ ["images/image2.png","banner, building,door, sign","sign"]
89
+ ],
90
+ # allow_flagging=False,
91
+ # analytics_enabled=False,
92
+ )
93
+ iface.launch()
images/image2.png ADDED
images/room.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ sentencepiece
4
+ huggingface_hub
5
+ numpy
6
+ scikit-image
7
+ opencv-python
8
+ Pillow
9
+ requests
10
+ urllib3<2
11
+ git+https://github.com/facebookresearch/segment-anything.git