Xp-age / app.py
MedicalAILabo's picture
Upload app.py and lib.
1f53a4c
raw
history blame
No virus
1.65 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path

import torch
import gradio as gr

from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# Resolve bundled assets relative to this file rather than the current
# working directory, so the app also works when launched from elsewhere
# (e.g. ``python path/to/app.py``). The values stay plain ``str`` paths so
# downstream helpers receive the same type as before.
_APP_DIR = Path(__file__).resolve().parent
test_weight = str(_APP_DIR / 'weight_epoch-200_best.pt')
parameter = str(_APP_DIR / 'parameters.json')
class ImageHandler(ImageMixin):
    """Apply the dataloader's image preprocessing to a single input image.

    The transform pipeline is built once, at construction time, by
    ``ImageMixin._make_transforms()``; presumably that method configures
    itself from ``self.params`` — TODO confirm against ``lib.dataloader``.
    """

    def __init__(self, params):
        # ``params`` must be stored before building the transforms, because
        # ``_make_transforms()`` runs right after and may read it.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform *image* and package it as model input.

        Args:
            image: The raw input image handed to ``self.transform``.

        Returns:
            dict: ``{'image': x}`` where ``x`` is the transformed image with
            ``unsqueeze(0)`` applied, i.e. a new leading axis of size 1.
        """
        batched = self.transform(image).unsqueeze(0)
        return {'image': batched}
def load_parameter(parameter):
    """Build model and dataloader argument groups from a saved parameter file.

    Every stored parameter is copied onto a fresh ``ParamSet``. Selected
    entries are then overridden for inference: ``augmentation`` and
    ``sampler`` become ``'no'``, ``pretrained`` becomes ``False``, ``mlp``
    becomes ``None``, ``net`` mirrors ``model``, and ``device`` is the CPU.

    Args:
        parameter: Path passed to ``_retrieve_parameter`` (this app uses
            ``parameters.json``).

    Returns:
        tuple: ``(args_model, args_dataloader)``, the results of
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'``
        groups, in that order.
    """
    args = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(args, name, value)

    # Inference-time overrides, applied after (and therefore winning over)
    # whatever the stored parameters contained.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
    }
    for name, value in inference_overrides.items():
        setattr(args, name, value)

    # Presumably downstream code reads the architecture from ``net``;
    # copy it from the stored ``model`` value — confirm in lib.
    args.net = args.model
    args.device = torch.device('cpu')

    model_args = _dispatch_by_group(args, 'model')
    dataloader_args = _dispatch_by_group(args, 'dataloader')
    return model_args, dataloader_args
# Build the model and load its trained weights once, at import time, so every
# request handled by ``main`` reuses this same module-level ``model``.
# ``load_parameter`` returns the 'model' group (used to construct the network)
# and the 'dataloader' group (later used by ``ImageHandler`` for preprocessing).
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
# ``load_weight`` is provided by the object returned from ``lib.create_model``
# (it is not a torch.nn.Module method); see lib for its exact semantics.
model.load_weight(test_weight)
def main(image):
    """Run the model on one uploaded image and return its score as text.

    Args:
        image: The image from the Gradio ``gr.Image(type='pil',
            image_mode='L')`` input, or ``None`` when the user submits
            without uploading anything.

    Returns:
        str: The first model output formatted with two decimal places.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of an opaque preprocessing traceback.
    """
    if image is None:
        raise gr.Error('Please upload an image.')

    model.eval()
    image_handler = ImageHandler(args_dataloader)
    batch = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(batch)

    # ``outputs`` is keyed by label name; only the first entry is reported.
    # ``.item()`` assumes that entry holds exactly one value. No ``detach()``
    # is needed because the forward pass ran under ``torch.no_grad()``.
    first_output = list(outputs.values())[0]
    return f"{first_output.item():.2f}"
# Gradio UI: one grayscale PIL image in, the formatted prediction as text out.
iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type='pil', image_mode='L')],
    outputs='text',
)

# Start the web server only when run as a script (e.g. ``python app.py``),
# so importing this module (for tests or reuse) does not launch a server.
if __name__ == '__main__':
    iface.launch()