anilbhatt1 committed on
Commit
80d1ab2
1 Parent(s): 902ac2f

Initial commit

Files changed (3)
  1. .gitignore +7 -0
  2. app.py +228 -0
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ gradio_cached_examples/
+ *.png
+ *.jpg
+ flagged/
+ *.pt
+ *.json
+ *.npy
app.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import json
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import timm
+ from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
+
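+ # Configuration values for the CLIP model; the architecture settings below must match
+ # the saved checkpoint so that load_state_dict succeeds at inference time.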
+ class CFG:
+     image_path = './images'
+     captions_path = './captions'
+     batch_size = 64
+     num_workers = 4
+     head_lr = 1e-3
+     image_encoder_lr = 1e-4
+     text_encoder_lr = 1e-5
+     weight_decay = 1e-3
+     patience = 1
+     factor = 0.8
+     epochs = 2
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_name = 'resnet50'
+     image_embedding = 2048
+     text_encoder_model = "distilbert-base-uncased"
+     text_embedding = 768
+     text_tokenizer = "distilbert-base-uncased"
+     max_length = 200
+
+     pretrained = True  # for both image encoder and text encoder
+     trainable = True   # for both image encoder and text encoder
+     temperature = 1.0
+
+     # image size
+     size = 224
+
+     # for projection head; used for both image and text encoders
+     num_projection_layers = 1
+     projection_dim = 256
+     dropout = 0.1
+
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed size vector
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         self.model = timm.create_model(
+             model_name, pretrained, num_classes=0, global_pool="avg"
+         )
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # we are using the CLS token hidden representation as the sentence's embedding
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
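+ # Maps an encoder output (2048-d image or 768-d text vector) into the shared 256-d
+ # embedding space: linear -> GELU -> linear, with dropout, a residual add and LayerNorm.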
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
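+ # CLIP-style dual encoder: image and text features are projected into the same space
+ # and scored against each other; the forward pass returns the symmetric contrastive loss.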
108
+ class CLIPModel(nn.Module):
109
+ def __init__(
110
+ self,
111
+ temperature=CFG.temperature,
112
+ image_embedding=CFG.image_embedding,
113
+ text_embedding=CFG.text_embedding,
114
+ ):
115
+ super().__init__()
116
+ self.image_encoder = ImageEncoder()
117
+ self.text_encoder = TextEncoder()
118
+ self.image_projection = ProjectionHead(embedding_dim=image_embedding)
119
+ self.text_projection = ProjectionHead(embedding_dim=text_embedding)
120
+ self.temperature = temperature
121
+
122
+ def forward(self, batch):
123
+ # Getting Image and Text Features
124
+ image_features = self.image_encoder(batch["image"])
125
+ text_features = self.text_encoder(
126
+ input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
127
+ )
128
+ # Getting Image and Text Embeddings (with same dimension)
129
+ image_embeddings = self.image_projection(image_features)
130
+ text_embeddings = self.text_projection(text_features)
131
+
132
+ # Calculating the Loss
133
+ images_similarity = image_embeddings @ text_embeddings.T / self.temperature
134
+ texts_similarity = images_similarity.T
135
+ labels = torch.arange(batch["image"].shape[0]).long().to(CFG.device)
136
+
137
+ total_loss = (
138
+ F.cross_entropy(images_similarity, labels) +
139
+ F.cross_entropy(texts_similarity, labels)
140
+ ) / 2
141
+
142
+ return total_loss
143
+
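+ # Encodes the text query on CPU, L2-normalises both the query and the precomputed image
+ # embeddings, and returns the filenames of the n best-matching images (every 5th index
+ # of the top n*5 is kept).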
144
+ def find_matches_cpu(model, image_embeddings, query, image_filenames, n=4):
145
+ tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
146
+ encoded_query = tokenizer([query])
147
+ batch = {
148
+ key: torch.tensor(values).to('cpu')
149
+ for key, values in encoded_query.items()
150
+ }
151
+ with torch.no_grad():
152
+ text_features = model.text_encoder(
153
+ input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
154
+ )
155
+ text_embeddings = model.text_projection(text_features)
156
+
157
+ image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
158
+ text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
159
+ dot_similarity = text_embeddings_n @ image_embeddings_n.T
160
+
161
+ values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
162
+ matches = [image_filenames[idx] for idx in indices[::5]]
163
+ return matches
164
+
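+ # The validation images are stored as flat run-length-encoded arrays of alternating
+ # (pixel_value, run_length) pairs; this rebuilds the grayscale image, saves it to disk
+ # and returns the saved path.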
+ def rle_decode(img_rle_array, img_name, img_size):
+     encoded_image = img_rle_array
+     # Initialize variables for decoding
+     decoded_image = []
+     for i in range(0, len(encoded_image), 2):
+         pixel_value = encoded_image[i]
+         run_length = encoded_image[i + 1]
+         decoded_image.extend([pixel_value] * run_length)
+
+     # Convert the decoded image back to a NumPy array
+     decoded_array = np.array(decoded_image, dtype=np.uint8)
+
+     # Reshape the decoded array to the original image shape given by img_size
+     decoded_image = decoded_array.reshape(img_size)
+
+     # Create a PIL Image from the decoded array
+     decoded_image = Image.fromarray(decoded_image)
+
+     decoded_image_save_path = './' + str(img_name)
+     # Save the decoded image to a file
+     decoded_image.save(decoded_image_save_path)
+     return decoded_image_save_path
+
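+ # Decodes each matched image (stored RLE-encoded at 112x112) and collects the saved file paths.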
+ def get_matched_image(matches, val_file_dict_loaded):
+     img_size = (112, 112)
+     match_img_list = []
+     for img_name in matches:
+         img_rle_array = val_file_dict_loaded[img_name]
+         decoded_image_path = rle_decode(img_rle_array, img_name, img_size)
+         match_img_list.append(decoded_image_path)
+     return match_img_list
+
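+ # Loads the trained CLIP weights and the precomputed validation-set image embeddings,
+ # runs the text query against them and returns the path(s) of the matching image(s).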
+ def get_grayscale_image(text_query):
+     model_inf = CLIPModel().to('cpu')
+     model_inf.load_state_dict(torch.load('./best_clip_model_cpu.pt', map_location='cpu'))
+
+     clip_image_embeddings_np_inf = np.load('./clip_image_embeddings.npy')
+     image_embeddings_inf = torch.tensor(clip_image_embeddings_np_inf)
+
+     img_file_names = np.load('./val_img_file_names.npy', allow_pickle=True)
+
+     with open("./val_imgs_rle_encode.json", "r") as json_file:
+         val_file_dict_loaded = json.load(json_file)
+
+     matches = find_matches_cpu(model_inf,
+                                image_embeddings_inf,
+                                query=text_query,
+                                image_filenames=img_file_names,
+                                n=1)
+
+     matched_images = get_matched_image(matches, val_file_dict_loaded)
+     return matched_images
+
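+ # Gradio handler: takes the text query and returns the best-matching image as a 224x224 array.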
+ def gradio_fn(text):
+     text_query = str(text)
+     match_img_list = get_grayscale_image(text_query)
+     pil_img = Image.open(match_img_list[0])
+     pil_img = pil_img.resize((224, 224))
+     np_img_array = np.array(pil_img)
+     return np_img_array
+
+ demo = gr.Interface(fn=gradio_fn, inputs="text", outputs="image", title="CLIP Image Search")
+
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ gradio
+ timm
+ torch