Felix92 committed on
Commit
13abb86
1 Parent(s): f062ee2

upload cpu version

Files changed (4)
  1. README.md +1 -1
  2. app.py +217 -0
  3. packages.txt +1 -0
  4. requirements.txt +2 -0
README.md CHANGED
@@ -2,7 +2,7 @@
  title: OnnxTR OCR
  emoji: 🔥
  colorFrom: red
- colorTo: yellow
+ colorTo: purple
  sdk: gradio
  sdk_version: 4.37.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,217 @@
+ import io
+ from typing import Any, List, Union
+
+ import cv2
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.figure import Figure
+ from onnxtr.io import DocumentFile
+ from onnxtr.models import ocr_predictor
+ from onnxtr.models.predictor import OCRPredictor
+ from onnxtr.utils.visualization import visualize_page
+ from PIL import Image
+
+ DET_ARCHS: List[str] = [
+     "fast_base",
+     "fast_small",
+     "fast_tiny",
+     "db_resnet50",
+     "db_resnet34",
+     "db_mobilenet_v3_large",
+     "linknet_resnet18",
+     "linknet_resnet34",
+     "linknet_resnet50",
+ ]
+ RECO_ARCHS: List[str] = [
+     "crnn_vgg16_bn",
+     "crnn_mobilenet_v3_small",
+     "crnn_mobilenet_v3_large",
+     "master",
+     "sar_resnet31",
+     "vitstr_small",
+     "vitstr_base",
+     "parseq",
+ ]
+
+
+ def load_predictor(
+     det_arch: str,
+     reco_arch: str,
+     assume_straight_pages: bool,
+     straighten_pages: bool,
+     bin_thresh: float,
+     box_thresh: float,
+ ) -> OCRPredictor:
+     """Load a predictor from onnxtr.models
+
+     Args:
+     ----
+         det_arch: detection architecture
+         reco_arch: recognition architecture
+         assume_straight_pages: whether to assume straight pages or not
+         straighten_pages: whether to straighten rotated pages or not
+         bin_thresh: binarization threshold for the segmentation map
+         box_thresh: minimal objectness score to consider a box
+
+     Returns:
+     -------
+         instance of OCRPredictor
+     """
+     predictor = ocr_predictor(
+         det_arch,
+         reco_arch,
+         assume_straight_pages=assume_straight_pages,
+         straighten_pages=straighten_pages,
+         export_as_straight_boxes=straighten_pages,
+         detect_orientation=not assume_straight_pages,
+     )
+     predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
+     predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
+     return predictor
+
+
+ def forward_image(predictor: OCRPredictor, image: np.ndarray) -> np.ndarray:
+     """Forward an image through the predictor
+
+     Args:
+     ----
+         predictor: instance of OCRPredictor
+         image: image to process
+
+     Returns:
+     -------
+         segmentation map
+     """
+     processed_batches = predictor.det_predictor.pre_processor([image])
+     out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
+     seg_map = out["out_map"]
+
+     return seg_map
+
+
+ def matplotlib_to_pil(fig: Union[Figure, np.ndarray]) -> Image.Image:
+     """Convert a matplotlib figure to a PIL image
+
+     Args:
+     ----
+         fig: matplotlib figure or numpy array
+
+     Returns:
+     -------
+         PIL image
+     """
+     buf = io.BytesIO()
+     if isinstance(fig, Figure):
+         fig.savefig(buf)
+     else:
+         plt.imsave(buf, fig)
+     buf.seek(0)
+     return Image.open(buf)
+
+
+ def analyze_page(
+     uploaded_file: Any,
+     page_idx: int,
+     det_arch: str,
+     reco_arch: str,
+     assume_straight_pages: bool,
+     straighten_pages: bool,
+     bin_thresh: float,
+     box_thresh: float,
+ ):
+     """Analyze a page
+
+     Args:
+     ----
+         uploaded_file: file to analyze
+         page_idx: index of the page to analyze
+         det_arch: detection architecture
+         reco_arch: recognition architecture
+         assume_straight_pages: whether to assume straight pages or not
+         straighten_pages: whether to straighten rotated pages or not
+         bin_thresh: binarization threshold for the segmentation map
+         box_thresh: minimal objectness score to consider a box
+
+     Returns:
+     -------
+         input image, segmentation heatmap, output image, OCR output
+     """
+     if uploaded_file is None:
+         return None, None, None, {"message": "Please upload a document"}
+
+     if uploaded_file.name.endswith(".pdf"):
+         doc = DocumentFile.from_pdf(uploaded_file)
+     else:
+         doc = DocumentFile.from_images(uploaded_file)
+
+     page = doc[page_idx - 1]
+     img = page
+
+     predictor = load_predictor(
+         det_arch,
+         reco_arch,
+         assume_straight_pages,
+         straighten_pages,
+         bin_thresh,
+         box_thresh,
+     )
+
+     seg_map = forward_image(predictor, page)
+     seg_map = np.squeeze(seg_map)
+     seg_map = cv2.resize(seg_map, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)
+     seg_heatmap = matplotlib_to_pil(seg_map)
+
+     out = predictor([page])
+
+     page_export = out.pages[0].export()
+     fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
+
+     out_img = matplotlib_to_pil(fig)
+
+     return img, seg_heatmap, out_img, page_export
+
+
+ with gr.Blocks(fill_height=True) as demo:
+     gr.Markdown("# **OnnxTR OCR demo**")
+     gr.Markdown("### This demo showcases the OCR capabilities of OnnxTR. **Github**: [OnnxTR](https://github.com/felixdittrich92/OnnxTR)")
+     with gr.Row():
+         with gr.Column(scale=1):
+             upload = gr.File(label="Upload File [JPG | PNG | PDF]", file_types=["pdf", "jpg", "png"])
+             page_selection = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Page selection")
+             det_model = gr.Dropdown(choices=DET_ARCHS, value=DET_ARCHS[0], label="Text detection model")
+             reco_model = gr.Dropdown(choices=RECO_ARCHS, value=RECO_ARCHS[0], label="Text recognition model")
+             assume_straight = gr.Checkbox(value=True, label="Assume straight pages")
+             straighten = gr.Checkbox(value=False, label="Straighten pages")
+             binarization_threshold = gr.Slider(
+                 minimum=0.1, maximum=0.9, value=0.3, step=0.1, label="Binarization threshold"
+             )
+             box_threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.1, step=0.1, label="Box threshold")
+             analyze_button = gr.Button("Analyze page")
+         with gr.Column(scale=3):
+             with gr.Row():
+                 input_image = gr.Image(label="Input page", width=600)
+                 segmentation_heatmap = gr.Image(label="Segmentation heatmap", width=600)
+                 output_image = gr.Image(label="Output page", width=600)
+             with gr.Column(scale=2):
+                 with gr.Row():
+                     gr.Markdown("### OCR output")
+                 with gr.Row():
+                     ocr_output = gr.JSON(label="OCR output", render=True, scale=1)
+
+     analyze_button.click(
+         analyze_page,
+         inputs=[
+             upload,
+             page_selection,
+             det_model,
+             reco_model,
+             assume_straight,
+             straighten,
+             binarization_threshold,
+             box_threshold,
+         ],
+         outputs=[input_image, segmentation_heatmap, output_image, ocr_output],
+     )
+
+ demo.launch(inbrowser=True)
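For reference, the OCR pipeline that app.py wires into Gradio can also be exercised without the UI. The snippet below is a minimal sketch that reuses only the OnnxTR calls appearing in the diff above; the file name `sample.jpg` is a placeholder, and the chosen architectures are simply the demo's defaults (the first entry of each dropdown).

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor

# Same defaults the demo exposes (first entry of each dropdown)
predictor = ocr_predictor(
    "fast_base",       # text detection architecture
    "crnn_vgg16_bn",   # text recognition architecture
    assume_straight_pages=True,
    straighten_pages=False,
)
# Optional: tune the detection post-processing, as load_predictor() does above
predictor.det_predictor.model.postprocessor.bin_thresh = 0.3
predictor.det_predictor.model.postprocessor.box_thresh = 0.1

pages = DocumentFile.from_images("sample.jpg")  # placeholder path
result = predictor(pages)
print(result.pages[0].export())  # same JSON structure as the demo's "OCR output" panel
```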
packages.txt ADDED
@@ -0,0 +1 @@
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ -e git+https://github.com/felixdittrich92/OnnxTR.git#egg=onnxtr[cpu,viz]
+ gradio>=4.37.1,<5.0.0