Xp-age / app.py
MedicalAILabo's picture
Update app.py
a94c67d
raw
history blame
No virus
3.23 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import gradio as gr
from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# Path of the weight file handed to model.load_weight() at start-up.
test_weight = './weight_epoch-200_best.pt'
# Path of the parameter file handed to load_parameter() at start-up.
parameter = './parameters.json'
class ImageHandler(ImageMixin):
    """Turn one input image into the batch dict the model is called with.

    The transform pipeline is built once, at construction time, by
    ``_make_transforms`` (not defined in this class; presumably supplied by
    ``ImageMixin`` -- verify in lib.dataloader).
    """

    def __init__(self, params):
        """Store ``params`` and build the transform pipeline.

        Args:
            params: Dataloader settings. They are stored on ``self.params``
                before ``_make_transforms`` is called.
        """
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform ``image`` and wrap it as a single-item batch.

        Args:
            image: The raw image to be passed through ``self.transform``.

        Returns:
            dict: ``{'image': batch}``, where ``batch`` is the transformed
            image with a new leading dimension added via ``unsqueeze(0)``.
        """
        transformed = self.transform(image)
        batch = transformed.unsqueeze(0)
        return {'image': batch}
def load_parameter(parameter):
    """Build the model and dataloader argument groups for inference.

    The values returned by ``_retrieve_parameter(parameter)`` are copied onto
    a fresh ``ParamSet``. Several attributes are then overridden with fixed
    inference-time values, and the result is split with
    ``_dispatch_by_group``.

    Args:
        parameter: Passed straight to ``_retrieve_parameter``; this file
            calls it with the path of a JSON parameter file.

    Returns:
        tuple: ``(args_model, args_dataloader)``, the ``'model'`` and
        ``'dataloader'`` groups produced by ``_dispatch_by_group``.
    """
    _args = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(_args, name, value)

    # Fixed inference-time settings that replace whatever the file stored.
    fixed_settings = (
        ('augmentation', 'no'),
        ('sampler', 'no'),
        ('pretrained', False),
        ('mlp', None),
    )
    for name, value in fixed_settings:
        setattr(_args, name, value)

    # `net` mirrors the stored `model` entry; inference is pinned to the CPU.
    _args.net = _args.model
    _args.device = torch.device('cpu')

    return (
        _dispatch_by_group(_args, 'model'),
        _dispatch_by_group(_args, 'dataloader'),
    )
# Build the model once at import time; main() reuses `model` and
# `args_dataloader` on every request.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
# Load the weights from the file named by `test_weight`.
model.load_weight(test_weight)
def main(image):
    """Run the model on one uploaded image and format its output for display.

    Args:
        image: The value delivered by the Gradio image input, or ``None``
            when the user has not provided an image.

    Returns:
        str: The model's scalar output for the first label, formatted with
        two decimal places.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the button is clicked with an empty image
        # component. Without this guard the transform fails with an opaque
        # traceback instead of a message the user can act on.
        raise gr.Error("Please provide an image before running inference.")

    model.eval()
    batch = ImageHandler(args_dataloader).set_image(image)
    with torch.no_grad():
        outputs = model(batch)

    # `outputs` is keyed by label name; only the first entry is reported.
    label_name = next(iter(outputs))
    value = outputs[label_name].detach().numpy().item()
    return f"{value:.2f}"
# Static HTML panel rendered through gr.HTML(html_content): input image
# requirements followed by the reference to the accompanying publication.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Image preprocess</h3>
<p>Only grayscale 320×320 resolution works appropriately.</p>
<p>The longest side of the Xp should be downscaled to 320 pixels while maintaining the aspect ratio,
and the width along the shorter side should be padded black to 320 pixels.
</p>
<h3>Publication Details</h3>
<p>See details in our publication, titled
"Chest radiography as a biomarker of ageing: artificial intelligence-based,
multi-institutional model development and validation in Japan"
</p>
<p><strong>Link:</strong> <a href="https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext" target="_blank">
https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext
</a></p>
</div>
"""
# Gradio UI: a title, the static info panel, an image input with a label
# output, a button that runs main(), and two clickable sample images.
# NOTE(review): this copy of the file had lost its indentation; the nesting
# below (which components sit inside each gr.Row) is reconstructed and
# should be confirmed against the deployed layout.
with gr.Blocks(title="Aging Biomarker from CXR",
               css=".gradio-container {background:mintcream;}"
               ) as demo:
    gr.HTML("""<div style="text-align:center"><h2>Aging Biomarker from CXR</h2></div>""")
    gr.HTML(html_content)
    with gr.Row():
        # The input is requested as a PIL image in mode "L" with shape (320, 320).
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label=gr.Label(label="Estimated age")
    send_btn = gr.Button("Inference")
    # Clicking the button feeds the current image to main() and shows its
    # return value in the label.
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)
    with gr.Row():
        # Each Examples widget loads its sample file into `input_image`.
        gr.Examples(['./samples/66_female_xp.png'], label='Sample CXR 1: 66 years old female', inputs=input_image)
        gr.Examples(['./samples/28_male_xp.png'], label='Sample CXR 2: 28 years old male', inputs=input_image)
demo.launch(debug=True)