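"""Gradio demo: classify a concrete image as cracked or uncracked with a small CNN
and visualize the decision with CAM heatmaps from the pytorch_grad_cam library,
rendered through matplotlib's magma colormap."""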
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, FullGrad
from matplotlib import colormaps
import numpy as np
import gradio as gr
class CNN(nn.Module):
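    """Small CNN for 224x224 RGB inputs: three conv+ReLU+pool blocks (16, 32, 64
    channels, each 2x2 pooling halving the spatial size down to 28x28) followed by
    two fully connected layers that emit logits for the two classes."""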
def __init__(self):
super(CNN, self).__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
# Pooling layer
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
# Fully connected layers
        self.fc1 = nn.Linear(64 * (224 // 8) * (224 // 8), 64)  # 64 channels x 28 x 28 spatial after three 2x2 poolings of the 224x224 input
self.fc2 = nn.Linear(64, 2) # 2 classes for binary classification
def forward(self, x):
# Convolutional layers with relu activation and pooling
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
# Flatten for fully connected layers
x = torch.flatten(x, 1)
# Fully connected layers with relu activation
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
transform = transforms.Compose([
transforms.Resize((224, 224)), # Resize to 224x224
transforms.ToTensor(), # Convert to tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize
])
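# NOTE: the resize and ImageNet mean/std normalization above are assumed to match
# the preprocessing used when the checkpoint was trained.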
model = CNN()
model.load_state_dict(torch.load("trained-cnn-concrete-crack.model", map_location=torch.device("cpu")))
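model.eval()  # switch to inference mode (good practice, even though this CNN has no dropout/batchnorm layers)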
magmaify = colormaps['magma']
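# `magmaify` maps a [0, 1] grayscale CAM to an RGBA float array via matplotlib's magma colormap.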
def compute_gradcam(img_tensor, layer_idx, typeCAM):
allCAMs = {"GradCAM": GradCAM, "HiResCAM": HiResCAM, "ScoreCAM": ScoreCAM, "GradCAMPlusPlus": GradCAMPlusPlus, "AblationCAM": AblationCAM, "XGradCAM": XGradCAM, "FullGrad": FullGrad}
target_layers = [[model.conv1], [model.conv2], [model.conv3]]
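    # `layer_idx` is the 1-based value from the Gradio slider, so shift it to index the list.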
    # Build the CAM inside a context manager so its forward/backward hooks are
    # released after each request (supported by recent pytorch_grad_cam releases).
    with allCAMs[typeCAM](model=model, target_layers=target_layers[layer_idx - 1]) as cam:
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)
    return magmaify(grayscale_cam.reshape(224, 224))
def predict_and_gradcam(model, img, layer_idx, typeCAM):
# Preprocess the image
img = Image.fromarray(img.astype('uint8'), 'RGB') if isinstance(img, np.ndarray) else img
img_tensor = transform(img).unsqueeze(0)
# Get predicted class index
with torch.no_grad():
output = model(img_tensor)
_, predicted = torch.max(output.data, 1)
predicted_label = str(predicted.item())
# Compute GradCAM
gradcam = compute_gradcam(img_tensor, layer_idx, typeCAM)
return predicted_label, gradcam
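# Map the predicted class index to a human-readable label; this must match the
# label encoding used at training time.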
idx_to_lbl = {"0": "Cracked", "1": "Uncracked"}
# Define a function to be used in Gradio app
def classify_image(image, layer_idx, typeCAM):
# Predict label and get GradCAM
label, gradcam_img = predict_and_gradcam(model, image, layer_idx, typeCAM)
return idx_to_lbl[label], gradcam_img
description = """\
<center>Upload an image of concrete and get the predicted label along with the GradCAM heatmap.</center>
<img src="https://www.huggingface.co/spaces/concrete-crack-gradcam/main/resolve/header.jpeg" />
\
"""
typeCAMs = ["GradCAM", "HiResCAM", "ScoreCAM", "GradCAMPlusPlus", "AblationCAM", "XGradCAM", "FullGrad"]
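# Dropdown choices must stay in sync with the keys of `allCAMs` in compute_gradcam.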
# Define Gradio interface
iface = gr.Interface(
fn=classify_image,
    inputs=[gr.Image(label="Concrete Image"), gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Conv Layer (1-3)"), gr.Dropdown(choices=typeCAMs, value="GradCAM", label="CAM Method")],
outputs=[gr.Textbox(label="Predicted Label"), gr.Image(label="GradCAM Heatmap")],
title="Concrete Crack Detection with GradCAM",
    description=description,
    allow_flagging="never"
)
# Launch the interface
iface.launch()