import time

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.nn import functional as F

from pipline import Transformer_Regression, extract_regions_Last, compute_ratios

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Define some parameters
# Side length of the square image fed to the model.
# NOTE(review): an earlier note read "512 got 87" -- presumably a score obtained
# at 512 px; context unknown, confirm before changing.
image_shape = 384
batch_size = 1
dim_patch = 4
num_classes = 3
label_smoothing = 0.1
scale = 1

# Vertical cup-to-disc ratios at or above this value are reported as a possible
# glaucoma risk (the UI describes it as a rule of thumb).
VCDR_THRESHOLD = 0.6

start = time.time()
torch.manual_seed(0)

# Preprocessing for every model input: resize to image_shape x image_shape,
# convert to a [0, 1] tensor, then map to [-1, 1] via (x - 0.5) / 0.5.
tfms = transforms.Compose([
    transforms.Resize((image_shape, image_shape)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])


def _draw_segmentation_contours(image, seg):
    """Draw the outer contours of the label map ``seg`` onto ``image`` in place.

    The region with label > 1 is outlined in (0, 255, 0) and the region with
    label > 0 in (0, 0, 255); on an RGB image those are green and blue.
    Presumably label 2 is the cup and label 1 the disc -- TODO confirm against
    ``pipline``.  ``image`` must be a C-contiguous uint8 array (an OpenCV
    requirement for in-place drawing).  Returns ``image``.
    """
    seg_u8 = np.uint8(seg)
    for level, color in ((1, (0, 255, 0)), (0, (0, 0, 255))):
        # THRESH_BINARY: pixels > level become 2, everything else 0.
        _, mask = cv2.threshold(seg_u8, level, 2, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image, contours, -1, color, 2)
    return image


def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2, vcdr_threshold=VCDR_THRESHOLD):
    """Segment one image, compute its vertical cup-to-disc ratio and draw contours.

    Args:
        Model: segmentation network.  It is called on a (1, C, H, W) tensor and
            must return a dict containing ``'seg'`` (and ``'seg_aux_1'`` when
            ``num_head == 2``) holding per-class scores.
        batch_sampler: dict with ``'image'`` (a ``tfms``-transformed C x H x W
            tensor) and ``'image_original'`` (the untransformed H x W x 3 array).
        num_head: when 2, the region returned by ``extract_regions_Last`` is
            re-segmented on its own, and the ``'seg_aux_1'`` prediction
            overwrites that region of the full-image prediction.
        vcdr_threshold: ratios at or above this value produce the risk message.

    Returns:
        (image_cont, vcdr, glaucoma, Regions_crop):
        - image_cont: uint8 H x W x 3 image with the contours drawn.
        - vcdr: ``compute_ratios(...).vcdr`` for the final label map.
        - glaucoma: the diagnosis string.
        - Regions_crop: the dict from ``extract_regions_Last``.  Its
          ``'image'`` entry has been replaced by an RGB PIL image, and
          ``'cord'`` holds (row0, row1, col0, col1) of the crop.
    """
    Model.eval()
    with torch.no_grad():
        model_input = batch_sampler['image'].to(device=device).unsqueeze(0)
        image_original = batch_sampler['image_original']
        height, width = image_original.shape[0], image_original.shape[1]

        # First pass on the whole image, upsampled back to the original size.
        scores = Model(model_input)
        yseg_pred = F.interpolate(scores['seg'], size=(height, width),
                                  mode='bilinear', align_corners=True)

        Regions_crop = extract_regions_Last(
            np.array(image_original),
            yseg_pred.argmax(1).long()[0].detach().cpu().numpy(),
        )
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
        cord = Regions_crop['cord']
        r0, r1, c0, c1 = cord[0], cord[1], cord[2], cord[3]

        if num_head == 2:
            # Second pass on the crop only; slicing (rather than r1 - r0)
            # keeps numpy's clipping semantics for out-of-range coordinates.
            crop_h, crop_w = image_original[r0:r1, c0:c1].shape[:2]
            crop_scores = Model(tfms(Regions_crop['image']).unsqueeze(0).to(device))
            yseg_pred[:, :, r0:r1, c0:c1] = F.interpolate(
                crop_scores['seg_aux_1'], size=(crop_h, crop_w),
                mode='bilinear', align_corners=True)

        # Per-pixel class label.  A softmax first would not change the argmax.
        seg = yseg_pred.argmax(1).long().cpu().numpy()
        ratios = compute_ratios(seg[0])

        # Rebuild a displayable image from the normalized model input
        # (undo Normalize(0.5, 0.5)), resized to the label map's size.
        display = F.interpolate(model_input, size=(seg.shape[1], seg.shape[2]),
                                mode='bilinear', align_corners=True)
        image_orig = np.uint8((display[0] * 0.5 + 0.5).permute(1, 2, 0).cpu().numpy() * 255)

        # The permuted array is not C-contiguous; OpenCV needs a contiguous
        # buffer to draw into.
        image_cont = _draw_segmentation_contours(np.ascontiguousarray(image_orig), seg[0])

    if ratios.vcdr < vcdr_threshold:
        glaucoma = 'None'
    else:
        glaucoma = 'May be there is a risk of Glaucoma'
    return image_cont, ratios.vcdr, glaucoma, Regions_crop


# Load the trained network once at import time; every request reuses it.
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
DeepLab.load_state_dict(torch.load("TrainAll_Maghrabi84_50iteration_SWIN.pth.tar",
                                   map_location=device))
DeepLab.eval()


def infer(img):
    """Gradio callback: analyse one uploaded fundus image.

    Args:
        img: the image as a numpy array (as supplied by ``gr.Image``), or
            None when nothing was uploaded.

    Returns:
        (ratio, diagnosis, annotated image, annotated image cropped to the
        region found by ``extract_regions_Last``).

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error('Please upload a retinal fundus image first.')
    img = np.asarray(img)
    if img.ndim == 2 or img.shape[-1] != 3:
        # Grayscale / RGBA inputs: the model and the contour drawing expect
        # a 3-channel RGB array.
        img = np.array(Image.fromarray(img).convert('RGB'))

    sample_batch = {
        'image_original': img,
        'image': tfms(Image.fromarray(img)),
    }
    result, ratio, diagnosis, region = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)
    cord = region['cord']
    zoomed = result[cord[0]:cord[1], cord[2]:cord[3]]
    return ratio, diagnosis, result, zoomed


title = "Glaucoma Detection in Retinal Fundus Images"
description = "The method detects disc and cup in the retinal image, then it computes the Vertical cup to disc ratio"

with gr.Blocks(css='#title {text-align : center;} ') as demo:
    with gr.Row():
        gr.Markdown(f"# {title}\n\n{description}", elem_id='title')
    with gr.Row():
        with gr.Column():
            prompt = gr.Image(label="Upload Your Retinal Fundus Image")
            btn = gr.Button(value='Submit')
            # With cache_examples=False an example click only fills the input
            # image; the user then presses Submit, so no fn/outputs are needed.
            examples = gr.Examples(
                ['M00027.png', 'M00056.png', 'M00073.png', 'M00093.png',
                 'M00018.png', 'M00034.png'],
                inputs=[prompt],
                cache_examples=False,
            )
        with gr.Column():
            with gr.Row():
                text1 = gr.Textbox(label="Vertical Cup to Disc Ratio:")
                text2 = gr.Textbox(label="Predicted Diagnosis (Rule of thumb ~0.6 or greater is suspicious)")
            img = gr.Image(label='Detected disc and cup')
            zoom = gr.Image(label='Croppped')
    outputs = [text1, text2, img, zoom]
    btn.click(fn=infer, inputs=prompt, outputs=outputs)

if __name__ == '__main__':
    demo.launch()