thomas0809 commited on
Commit
bd8cfdf
1 Parent(s): 8ce358c
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import os
4
+ import glob
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from rxnscribe import RxnScribe
9
+
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ REPO_ID = "yujieq/RxnScribe"
13
+ FILENAME = "pix2seq_reaction_full.ckpt"
14
+ ckpt_path = hf_hub_download(REPO_ID, FILENAME)
15
+
16
+ device = torch.device('cpu')
17
+ model = RxnScribe(ckpt_path, device)
18
+
19
+
20
+ def get_markdown(reaction):
21
+ output = []
22
+ for x in ['reactants', 'conditions', 'products']:
23
+ s = ''
24
+ for ent in reaction[x]:
25
+ if 'smiles' in ent:
26
+ s += ent['smiles'] + '<br>'
27
+ elif 'text' in ent:
28
+ s += ' '.join(ent['text']) + '<br>'
29
+ else:
30
+ s += ent['category']
31
+ output.append(s)
32
+ return output
33
+
34
+
35
+ def predict(image, molscribe, ocr):
36
+ predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
37
+ pred_images = model.draw_predictions(predictions, image=image)
38
+ markdown = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
39
+ return pred_images, markdown
40
+
41
+
42
+ with gr.Blocks() as demo:
43
+ with gr.Column():
44
+ with gr.Row():
45
+ image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
46
+ with gr.Row():
47
+ molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
48
+ ocr = gr.Checkbox(label="Run OCR to recognize text")
49
+ btn = gr.Button("Submit").style(full_width=False)
50
+ with gr.Row():
51
+ gallery = gr.Gallery(
52
+ label="Predicted reactions", show_label=False, elem_id="gallery"
53
+ ).style(height="auto")
54
+ markdown = gr.Dataframe(
55
+ headers=['#', 'reactant', 'condition', 'product'],
56
+ datatype=['number'] + ['markdown'] * 3,
57
+ wrap=False
58
+ )
59
+
60
+ btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])
61
+
62
+ demo.launch()