Xp-age / app.py
MedicalAILabo's picture
Update app.py
dd2f551
raw
history blame
No virus
3.22 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import gradio as gr
from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# Checkpoint file whose weights are loaded into the model at module import.
# NOTE(review): "./"-relative, so presumably resolved against the process's
# working directory — confirm the launch directory.
test_weight = "./weight_epoch-200_best.pt"
# Saved option file that load_parameter() converts into model/dataloader settings.
parameter = "./parameters.json"
class ImageHandler(ImageMixin):
    """Turn one input image into the dict-shaped batch the model is called with.

    The transform pipeline comes from ``ImageMixin._make_transforms``.
    """

    def __init__(self, params):
        # Must be stored before _make_transforms() runs; presumably that helper
        # reads its options from self.params — TODO confirm in lib.dataloader.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Apply the transform and return ``{'image': tensor}`` with a new leading dim (batch of one)."""
        return {'image': self.transform(image).unsqueeze(0)}
def load_parameter(parameter):
    """Rebuild the saved options and split them into model and dataloader groups.

    Args:
        parameter: Path of the saved parameter file (e.g. ``./parameters.json``),
            read through ``_retrieve_parameter``.

    Returns:
        ``(args_model, args_dataloader)`` — the results of ``_dispatch_by_group``
        for the ``'model'`` and ``'dataloader'`` groups, in that order.
    """
    args = ParamSet()
    # Copy every stored option onto the namespace object.
    for name, value in _retrieve_parameter(parameter).items():
        setattr(args, name, value)

    # Force inference-time settings regardless of what was saved; 'net'
    # mirrors the saved 'model' option, and everything runs on the CPU.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
        'net': args.model,
        'device': torch.device('cpu'),
    }
    for name, value in inference_overrides.items():
        setattr(args, name, value)

    return (
        _dispatch_by_group(args, 'model'),
        _dispatch_by_group(args, 'dataloader'),
    )
# Load configuration and trained weights once at import time so every request
# reuses the same model.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
# The app only runs inference, so switch to eval mode once here instead of
# on every request.
model.eval()
# Preprocessing depends only on args_dataloader, so build the transform
# pipeline once rather than reconstructing it for each request.
image_handler = ImageHandler(args_dataloader)


def main(image):
    """Estimate age from a single chest radiograph.

    Args:
        image: Image delivered by the Gradio ``gr.Image`` input (PIL, mode "L"),
            or ``None`` when the button is clicked without an upload.

    Returns:
        The value of the model's first output, formatted with two decimals
        (e.g. ``"66.23"``).

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an opaque transform failure.
    """
    if image is None:
        raise gr.Error("Please upload a chest radiograph before running inference.")
    batch = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(batch)
    # outputs maps label name -> tensor; only the first label is reported.
    label_name = next(iter(outputs))
    # .item() requires a single-element tensor (one image, one value).
    result = outputs[label_name].item()
    return f"{result:.2f}"
# Static HTML panel rendered below the page title. It gives users the expected
# input preprocessing (grayscale, 320x320, aspect-preserving downscale plus
# black padding) and a link to the publication describing the model.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Image preprocess</h3>
<p>Only grayscale 320×320 resolution works appropriately.</p>
<p>The longest side of the Xp was downscaled to 320 pixels while maintaining the aspect ratio,
and the width along the shorter side was then padded black to 320 pixels.
</p>
<h3>Publication Details</h3>
<p>See details in our publication, titled
"Chest radiography as a biomarker of ageing: artificial intelligence-based,
multi-institutional model development and validation in Japan"
</p>
<p><strong>Link:</strong> <a href="https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext" target="_blank">
https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext
</a></p>
</div>
"""
# Gradio
# The page layout is defined by the order in which components are created
# inside the Blocks/Row context managers, so statement order here matters.
with gr.Blocks(title="Aging Biomarker from CXR",
               css=".gradio-container {background:mintcream;}"
               ) as demo:
    # Page title, then the preprocessing/citation panel.
    gr.HTML("""<div style="text-align:center"><h2>Aging Biomarker from CXR</h2></div>""")
    gr.HTML(html_content)
    with gr.Row():
        # Input delivered to main() as a single-channel ("L") PIL image.
        # NOTE(review): shape=(320, 320) is a Gradio 3.x argument (presumably
        # resizing uploads to 320x320) that was removed in Gradio 4 — confirm
        # the pinned gradio version before upgrading.
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label=gr.Label(label="Estimated age")
    # Clicking the button calls main(input_image) and shows the returned
    # string in output_label.
    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)
    with gr.Row():
        # Bundled sample radiographs offered as clickable examples for input_image.
        gr.Examples(['./samples/66_female_xp.png'], label='Sample CXR 1: 66 years old female', inputs=input_image)
        gr.Examples(['./samples/28_male_xp.png'], label='Sample CXR 2: 28 years old male', inputs=input_image)
# Start the web app; this runs at module level, i.e. whenever the file is executed/imported.
demo.launch(debug=True)