"""Gradio demo: concrete-crack classification with CAM heatmaps."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim  # not used by the app (presumably left from training)
from torch.utils.data import Dataset, DataLoader, Subset  # not used by the app
from torchvision import transforms, datasets
from PIL import Image
from tqdm.auto import tqdm  # not used by the app
from matplotlib import colormaps
from pytorch_grad_cam import (
    GradCAM,
    HiResCAM,
    ScoreCAM,
    GradCAMPlusPlus,
    AblationCAM,
    XGradCAM,
    FullGrad,
)
import gradio as gr

# Side length (pixels) every input is resized to; fc1's input size depends on it.
IMG_SIZE = 224


class CNN(nn.Module):
    """Three-stage conv net producing 2 logits for binary crack classification.

    Expects a (batch, 3, IMG_SIZE, IMG_SIZE) float tensor and returns a
    (batch, 2) tensor of unnormalised class scores.
    """

    def __init__(self):
        super().__init__()
        # Convolutional layers; padding=1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Pooling layer; each application halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Three 2x poolings shrink each spatial dimension by a factor of 8.
        self.fc1 = nn.Linear(64 * (IMG_SIZE // 8) * (IMG_SIZE // 8), 64)
        self.fc2 = nn.Linear(64, 2)  # 2 classes for binary classification

    def forward(self, x):
        # Conv -> ReLU -> pool, three times.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension for the dense layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# Preprocessing for every input image. The mean/std are the usual ImageNet
# statistics (presumably the same ones used at training time; confirm).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = CNN()
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files.
# On PyTorch >= 1.13, passing weights_only=True is the safer choice.
model.load_state_dict(
    torch.load("trained-cnn-concrete-crack.model", map_location=torch.device("cpu"))
)
model.eval()  # the app only runs inference

magmaify = colormaps["magma"]

# Supported CAM algorithms, keyed by the name shown in the UI dropdown.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAMPlusPlus": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "FullGrad": FullGrad,
}

# Predicted class index (as the string predict_and_gradcam returns) -> label.
idx_to_lbl = {"0": "Cracked", "1": "Uncracked"}


def compute_gradcam(img_tensor, layer_idx, typeCAM, net=None):
    """Return a magma-coloured CAM heatmap for a single preprocessed image.

    Args:
        img_tensor: Batch holding one image, as built by
            ``transform(img).unsqueeze(0)``.
        layer_idx: 1-based conv layer to explain (1=conv1, 2=conv2, 3=conv3).
            Integral floats such as ``2.0`` are accepted.
        typeCAM: Key of ``CAM_METHODS`` selecting the CAM algorithm.
        net: Model to explain. Defaults to the module-level ``model``.

    Returns:
        The colormapped heatmap of the first (only) image in the batch, as an
        RGBA float array of shape (H, W, 4).

    Raises:
        ValueError: If ``layer_idx`` is out of range or ``typeCAM`` is unknown.
    """
    net = model if net is None else net
    conv_layers = [net.conv1, net.conv2, net.conv3]

    layer_idx = int(layer_idx)  # UI sliders may deliver numbers as floats
    if not 1 <= layer_idx <= len(conv_layers):
        raise ValueError(
            f"layer_idx must be between 1 and {len(conv_layers)}, got {layer_idx}"
        )
    try:
        cam_cls = CAM_METHODS[typeCAM]
    except KeyError:
        raise ValueError(
            f"Unknown CAM method {typeCAM!r}; choose one of {list(CAM_METHODS)}"
        ) from None

    cam = cam_cls(model=net, target_layers=[conv_layers[layer_idx - 1]])
    # Per the pytorch-grad-cam docs, targets=None explains the top-scoring class.
    grayscale_cam = cam(input_tensor=img_tensor, targets=None)
    # One heatmap per batch item; take the single image's map.
    return magmaify(grayscale_cam[0])


def predict_and_gradcam(model, img, layer_idx, typeCAM):
    """Classify ``img`` with ``model`` and compute a CAM heatmap for it.

    Args:
        model: Classifier to run. The heatmap explains this same model.
        img: Image as a numpy array (converted to uint8) or a PIL image.
            Grayscale, RGBA and palette images are converted to RGB.
        layer_idx: 1-based conv layer to explain (see ``compute_gradcam``).
        typeCAM: Key of ``CAM_METHODS``.

    Returns:
        ``(predicted_label, heatmap)``. ``predicted_label`` is the argmax class
        index as a string (``"0"`` or ``"1"``). ``heatmap`` is the RGBA array
        from ``compute_gradcam``.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype(np.uint8))
    # The network and normalisation expect exactly 3 channels.
    img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(img_tensor)
    predicted_label = str(logits.argmax(dim=1).item())

    gradcam = compute_gradcam(img_tensor, layer_idx, typeCAM, net=model)
    return predicted_label, gradcam


def classify_image(image, layer_idx, typeCAM):
    """Gradio callback: return ``(label text, heatmap)`` for an uploaded image."""
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image first.")
    label, gradcam_img = predict_and_gradcam(model, image, layer_idx, typeCAM)
    return idx_to_lbl[label], gradcam_img


description = """\
Upload an image of concrete and get the predicted label along with the GradCAM heatmap.
"""

# Dropdown choices come from CAM_METHODS so the two lists cannot drift apart.
typeCAMs = list(CAM_METHODS)

# Define Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(),
        gr.Slider(minimum=1, maximum=3, step=1, value=1),  # conv1..conv3
        gr.Dropdown(choices=typeCAMs, value="GradCAM"),
    ],
    outputs=[gr.Textbox(label="Predicted Label"), gr.Image(label="GradCAM Heatmap")],
    title="Concrete Crack Detection with GradCAM",
    description=description,
    # Gradio expects "never" / "auto" / "manual"; the old boolean False was
    # only mapped to "never" with a deprecation warning.
    allow_flagging="never",
    theme=gr.themes.Monochrome(font=gr.themes.GoogleFont("IBM Plex Mono")),
)

# Launch the interface
iface.launch()