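# Hugging Face file-page header captured with this source (not code):
#   author: bigmed@bigmed
#   commit 5acc242: "moved the rule of thumb to title"
#   page links: raw | history | blame
#   scan status: No virus
#   file size: 7.28 kB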
import torch
import torchvision.transforms as transforms
from torch.nn import functional as F
import cv2
import gradio as gr
import numpy as np
from PIL import Image
from pipline import Transformer_Regression, extract_regions_Last , compute_ratios
import time

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network / training hyper-parameters.
# image_shape is the side length of the square network input
# (an earlier note on this value read "512 got 87").
image_shape = 384
batch_size, dim_patch, num_classes, label_smoothing, scale = 1, 4, 3, 0.1, 1

# Wall-clock reference taken at import time, then a fixed RNG seed.
start = time.time()
torch.manual_seed(0)

# Preprocessing: resize to image_shape x image_shape, convert to a float
# tensor in [0, 1], then normalize every channel with mean = std = 0.5.
tfms = transforms.Compose(
    [
        transforms.Resize((image_shape, image_shape)),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)
def _tensor_to_uint8_rgb(image_tensor, size):
    """Upsample a ``tfms``-normalized CHW tensor to ``size`` and return an HxWxC uint8 array.

    ``tfms`` ends with ``Normalize(0.5, 0.5)``, so ``x * 0.5 + 0.5`` maps the
    values back to [0, 1]. The result is clamped before scaling to 255 because
    float rounding in the resampling may push values marginally outside
    [0, 1], and a negative float does not convert cleanly to uint8.
    """
    resized = F.interpolate(image_tensor.unsqueeze(0), size=size,
                            mode='bilinear', align_corners=True)
    unit_range = (resized[0] * 0.5 + 0.5).clamp(0.0, 1.0)
    return np.uint8(unit_range.permute(1, 2, 0).detach().cpu().numpy() * 255)


def _draw_disc_cup_contours(image, label_map):
    """Return a copy of ``image`` with the outer outlines of two label regions drawn.

    Pixels with label > 1 are outlined in green (0, 255, 0) and pixels with
    label > 0 in blue (0, 0, 255); colours are RGB because ``image`` comes
    from a PIL-decoded tensor. Presumably label 2 is the cup and label 1 the
    disc (so "> 0" covers disc + cup) -- TODO confirm against ``pipline``.
    """
    image_cont = image.copy()  # drawing is in place; keep the caller's array intact
    labels = np.uint8(label_map)
    for thresh_value, colour in ((1, (0, 255, 0)), (0, (0, 0, 255))):
        # Binary mask: label > thresh_value -> 2 (any non-zero works for findContours).
        _, mask = cv2.threshold(labels, thresh_value, 2, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, contours, -1, colour, 2)
    return image_cont


def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2, vcdr_threshold=0.6):
    """Segment one fundus image, outline the predicted regions and score the vertical CDR.

    The model runs once on the full (``tfms``-resized) image. The logits are
    upsampled to the original resolution, and their argmax goes to
    ``extract_regions_Last`` to obtain a crop around the detected region.
    When ``num_head == 2`` the crop goes through the model a second time. Its
    ``'seg_aux_1'`` logits then replace the full-image logits inside the crop
    window before the final per-pixel argmax.

    Args:
        Model: segmentation network called with a (1, C, H, W) tensor. It
            must return a dict with ``'seg'`` logits, plus ``'seg_aux_1'``
            when ``num_head == 2``.
        batch_sampler: dict with ``'image'`` (the ``tfms``-transformed CHW
            tensor) and ``'image_original'`` (the original image array,
            indexed as [rows, cols, ...]).
        num_head: 2 to refine the crop with the auxiliary head; any other
            value keeps the full-image prediction only.
        vcdr_threshold: vertical cup-to-disc ratio at or above which the
            diagnosis string flags a possible glaucoma risk.

    Returns:
        ``(image_cont, vcdr, glaucoma, Regions_crop)``:
        - ``image_cont``: the de-normalized image (HxWx3 uint8) with contours drawn.
        - ``vcdr``: the ``ratios.vcdr`` value from ``compute_ratios``.
        - ``glaucoma``: the diagnosis string.
        - ``Regions_crop``: the dict from ``extract_regions_Last``, with its
          ``'image'`` entry converted to a PIL RGB image.
    """
    Model.eval()
    with torch.no_grad():
        image_tensor = batch_sampler['image'].to(device=device)
        image_original = batch_sampler['image_original']
        full_size = (image_original.shape[0], image_original.shape[1])

        # Pass 1: whole image -> logits upsampled to the original resolution.
        scores = Model(image_tensor.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'], size=full_size,
                                  mode='bilinear', align_corners=True)

        Regions_crop = extract_regions_Last(
            np.array(image_original),
            yseg_pred.argmax(1).long()[0].detach().cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')

        if num_head == 2:
            cord = Regions_crop['cord']
            rows = slice(cord[0], cord[1])
            cols = slice(cord[2], cord[3])
            # Size of the crop window as actually clipped to the image bounds.
            crop_size = image_original[rows, cols].shape[:2]
            # Pass 2: re-segment the crop with the auxiliary head and paste
            # its logits over the full-image logits inside the crop window.
            crop_scores = Model(tfms(Regions_crop['image']).unsqueeze(0).to(device))
            yseg_pred[:, :, rows, cols] = F.interpolate(
                crop_scores['seg_aux_1'], size=crop_size,
                mode='bilinear', align_corners=True)

        # argmax over logits equals argmax over softmax (softmax is monotonic).
        label_map = yseg_pred.argmax(1)[0].long().detach().cpu().numpy()
        ratios = compute_ratios(label_map)
        image_cont = _draw_disc_cup_contours(
            _tensor_to_uint8_rgb(image_tensor, full_size), label_map)

    if ratios.vcdr < vcdr_threshold:
        glaucoma = 'None'
    else:
        glaucoma = 'May be there is a risk of Glaucoma'
    return image_cont, ratios.vcdr, glaucoma, Regions_crop
# Load the segmentation model and its trained weights once, at import time.
# The checkpoint is a local file shipped with the app; torch.load unpickles it,
# so never point this at an untrusted file.
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
DeepLab.load_state_dict(torch.load("TrainAll_Maghrabi84_50iteration_SWIN.pth.tar", map_location=device))
# Inference-only app: switch off training-mode behaviour right after loading.
DeepLab.eval()
def infer(img):
    """Run the disc/cup segmentation pipeline on one uploaded fundus image.

    Args:
        img: the image delivered by ``gr.Image`` (a numpy array with the
            component's default ``type='numpy'``), or ``None`` when Submit is
            pressed without an upload.

    Returns:
        ``(ratio, diagnosis, result, cropped)``:
        - ``ratio``: the vertical cup-to-disc ratio.
        - ``diagnosis``: the diagnosis string.
        - ``result``: the full image with contours drawn.
        - ``cropped``: that contoured image cut to the detected crop window.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable
            message instead of a traceback from ``Image.fromarray(None)``.
    """
    if img is None:
        raise gr.Error('Please upload a retinal fundus image first.')
    sample_batch = {
        'image_original': img,
        'image': tfms(Image.fromarray(img)),
    }
    result, ratio, diagnosis, regions = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)
    cord = regions['cord']
    cropped = result[cord[0]:cord[1], cord[2]:cord[3]]
    return ratio, diagnosis, result, cropped
title = "Glaucoma Detection in Retinal Fundus Images"
description = "The method detects disc and cup in the retinal image, then it computes the Vertical cup to disc ratio"

with gr.Blocks(css='#title {text-align : center;} ') as demo:
    with gr.Row():
        # Built without leading indentation so Markdown does not render it as a code block.
        gr.Markdown(f"# {title}\n{description}", elem_id='title')
    with gr.Row():
        with gr.Column():
            prompt = gr.Image(label="Upload Your Retinal Fundus Image")
            btn = gr.Button(value='Submit')
            # Examples only fill the input (no caching, no run-on-click); the
            # user then presses Submit. An earlier version also passed fn and
            # a nested, never-rendered outputs list; neither was ever used.
            examples = gr.Examples(
                ['M00027.png', 'M00056.png', 'M00073.png', 'M00093.png', 'M00018.png', 'M00034.png'],
                inputs=[prompt],
                cache_examples=False,
            )
        with gr.Column():
            with gr.Row():
                text1 = gr.Textbox(label="Vertical Cup to Disc Ratio:")
                text2 = gr.Textbox(label="Predicted Diagnosis (Rule of thumb ~0.6 or greater is suspicious)")
            img = gr.Image(label='Detected disc and cup')
            zoom = gr.Image(label='Cropped')
    # Order matches infer()'s return tuple: ratio, diagnosis, full image, crop.
    outputs = [text1, text2, img, zoom]
    btn.click(fn=infer, inputs=prompt, outputs=outputs)

if __name__ == '__main__':
    demo.launch()