vincentclaes committed on
Commit
f0a21d6
1 Parent(s): 05957fd

refactor code

Files changed (1): app.py (+82, -114)
app.py CHANGED
@@ -2,6 +2,7 @@ import io
 import os
 import boto3
 import traceback
+import re
 
 import gradio as gr
 from PIL import Image, ImageDraw
@@ -10,43 +11,37 @@ from docquery.document import load_document, ImageDocument
 from docquery.ocr_reader import get_ocr_reader
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 from transformers import DonutProcessor, VisionEncoderDecoderModel
+from transformers import pipeline
 
 # avoid ssl errors
 import ssl
+
 ssl._create_default_https_context = ssl._create_unverified_context
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+# Init models
 
-def ensure_list(x):
-    if isinstance(x, list):
-        return x
-    else:
-        return [x]
+layoutlm_pipeline = pipeline(
+    "document-question-answering",
+    model="impira/layoutlm-document-qa",
+)
+lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
+lilt_model = AutoModelForQuestionAnswering.from_pretrained(
+    "nielsr/lilt-xlm-roberta-base"
+)
 
-CHECKPOINTS = {
-    # "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
-    # "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
-    "Textract Query": "Textract",
-    "LayoutLM FineTuned": "LayoutLM FineTuned",
-    "Donut": "naver-clova-ix/donut-base-finetuned-rvlcdip",
-    "LiLT": "philschmid/lilt-en-funsd",
-    # "LiLT" : "nielsr/lilt-xlm-roberta-base"
-}
+donut_processor = DonutProcessor.from_pretrained(
+    "naver-clova-ix/donut-base-finetuned-docvqa"
+)
+donut_model = VisionEncoderDecoderModel.from_pretrained(
+    "naver-clova-ix/donut-base-finetuned-docvqa"
+)
 
-PIPELINES = {}
-#
-#
-# def construct_pipeline(task, model):
-#     global PIPELINES
-#     if model in PIPELINES:
-#         return PIPELINES[model]
-#
-#     device = "cuda" if torch.cuda.is_available() else "cpu"
-#     ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
-#     PIPELINES[model] = ret
-#     return ret
+TEXTRACT = "Textract Query"
+LAYOUTLM = "LayoutLM"
+DONUT = "Donut"
+LILT = "LiLT"
 
 
 def image_to_byte_array(image: Image) -> bytes:
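Note that this hunk loads all four models eagerly at module level, so the Space starts slower but every query hits a warm model. If startup time mattered more, a lazy alternative (a sketch, not part of this commit; the function name is hypothetical) could memoize each loader:

from functools import lru_cache

@lru_cache(maxsize=None)
def get_layoutlm_pipeline():
    # Loaded on first use and cached for all subsequent queries.
    from transformers import pipeline
    return pipeline(
        "document-question-answering",
        model="impira/layoutlm-document-qa",
    )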
@@ -56,25 +51,25 @@ def image_to_byte_array(image: Image) -> bytes:
     return image_as_byte_array
 
 
-def run_textract_query(question, document):
+def run_textract(question, document):
     image_as_byte_base64 = image_to_byte_array(image=document.b)
-    response = boto3.client('textract').analyze_document(
+    response = boto3.client("textract").analyze_document(
         Document={
-            'Bytes': image_as_byte_base64,
+            "Bytes": image_as_byte_base64,
         },
         FeatureTypes=[
-            'QUERIES',
+            "QUERIES",
         ],
         QueriesConfig={
-            'Queries': [
+            "Queries": [
                 {
-                    'Text': question,
-                    'Pages': [
-                        '*',
-                    ]
+                    "Text": question,
+                    "Pages": [
+                        "*",
+                    ],
                 },
             ]
-        }
+        },
     )
     for element in response["Blocks"]:
        if element["BlockType"] == "QUERY_RESULT":
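The hunk cuts off inside the response loop. For context (not part of the diff): Textract returns QUERY_RESULT blocks whose Text and Confidence fields carry the extracted answer. A minimal sketch of the rest of the parse, assuming that response shape and a hypothetical helper name:

def parse_textract_answer(response: dict) -> dict:
    # Each QUERY_RESULT block holds the answer text and a 0-100 confidence.
    for element in response["Blocks"]:
        if element["BlockType"] == "QUERY_RESULT":
            return {"answer": element["Text"], "score": element["Confidence"]}
    raise Exception("No QUERY_RESULT found in the response from Textract.")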
@@ -87,75 +82,60 @@ def run_textract_query(question, document):
     Exception("No QUERY_RESULT found in the response from Textract.")
 
 
-def run_layoutlm_finetuned(question, document):
-    from transformers import pipeline
-
-    nlp = pipeline(
-        "document-question-answering",
-        model="impira/layoutlm-document-qa",
-    )
-
-    result = nlp(document.context["image"][0][0], question)[0]
+def run_layoutlm(question, document):
+    result = layoutlm_pipeline(document.context["image"][0][0], question)[0]
     # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
     return {
         "score": result["score"],
         "answer": result["answer"],
         "word_ids": [result["start"], result["end"]],
-        "page": 0
+        "page": 0,
     }
 
 
-def run_lilt_model(question, document):
-
+def run_lilt(question, document):
     # use this model + tokenizer
-    lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
-    model = AutoModelForQuestionAnswering.from_pretrained("nielsr/lilt-xlm-roberta-base")
-
     processed_document = document.context["image"][0][1]
     words = [x[0] for x in processed_document]
     boxes = [x[1] for x in processed_document]
 
-    encoding = lilt_tokenizer(text=question, text_pair=words, boxes=boxes, add_special_tokens=True, return_tensors="pt")
-
-    outputs = model(**encoding)
+    encoding = lilt_tokenizer(
+        text=question,
+        text_pair=words,
+        boxes=boxes,
+        add_special_tokens=True,
+        return_tensors="pt",
+    )
+    outputs = lilt_model(**encoding)
 
     answer_start_index = outputs.start_logits.argmax()
     answer_end_index = outputs.end_logits.argmax()
 
-    predict_answer_tokens = encoding.input_ids[0, answer_start_index: answer_end_index + 1]
-    predict_answer = lilt_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+    predict_answer_tokens = encoding.input_ids[
+        0, answer_start_index: answer_end_index + 1
+    ]
+    predict_answer = lilt_tokenizer.decode(
+        predict_answer_tokens, skip_special_tokens=True
+    )
     return {
         "score": "n/a",
         "answer": predict_answer,
         # "word_ids": element
     }
 
 
 def run_donut(question, document):
-
-    # nlp = pipeline(
-    #     "document-question-answering",
-    #     model="naver-clova-ix/donut-base-finetuned-docvqa",
-    # )
-    #
-    # result = nlp(document.context["image"][0][0], question)[0]
-    # # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
-    # return {
-    #     "score": result["score"],
-    #     "answer": result["answer"],
-    #     "word_ids": [result["start"], result["end"]],
-    #     "page": 0
-    # }
-
-    donut_processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
-    donut_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
     # prepare encoder inputs
-    pixel_values = donut_processor(document.context["image"][0][0], return_tensors="pt").pixel_values
+    pixel_values = donut_processor(
+        document.context["image"][0][0], return_tensors="pt"
+    ).pixel_values
 
     # prepare decoder inputs
     task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
     prompt = task_prompt.replace("{user_input}", question)
-    decoder_input_ids = donut_processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+    decoder_input_ids = donut_processor.tokenizer(
+        prompt, add_special_tokens=False, return_tensors="pt"
+    ).input_ids
 
     # generate answer
     outputs = donut_model.generate(
@@ -170,11 +150,13 @@ def run_donut(question, document):
         bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
         return_dict_in_generate=True,
     )
-    import re
-    # postprocess
     sequence = donut_processor.batch_decode(outputs.sequences)[0]
-    sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
-    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+    sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(
+        donut_processor.tokenizer.pad_token, ""
+    )
+    sequence = re.sub(
+        r"<.*?>", "", sequence, count=1
+    ).strip()  # remove first task start token
 
     result = donut_processor.token2json(sequence)
     return {
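Aside (not part of the diff): token2json converts Donut's generated markup back into a dict, which is where the "answer" key below comes from. A quick illustration, with the sequence value assumed:

# Assumed example sequence after stripping EOS/PAD and the task start token:
sequence = "<s_question> What is the total?</s_question><s_answer> $ 1,000</s_answer>"
print(donut_processor.token2json(sequence))
# -> roughly {'question': 'What is the total?', 'answer': '$ 1,000'}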
@@ -184,26 +166,6 @@ def run_donut(question, document):
     }
 
 
-def run_pipeline(model, question, document, top_k):
-    """ Run pipeline selected by the user.
-    :return: expect an object like
-    [{'score': 0.251716673374176, 'answer': 'CREDIT', 'word_ids': [38], 'page': 0},
-    {'score': 0.15292450785636902, 'answer': 'LETTER OF CREDIT', 'word_ids': [37, 38], 'page': 0},
-    {'score': 0.009600160643458366, 'answer': 'Payment Tens LETTER OF CREDIT', 'word_ids': [36, 37, 38], 'page': 0}]
-    """
-    if model == "Textract Query":
-        return run_textract_query(question, document)
-    elif model == "LiLT":
-        return run_lilt_model(question, document)
-    elif model == "LayoutLM FineTuned":
-        return run_layoutlm_finetuned(question=question, document=document)
-    elif model == "Donut":
-        return run_donut(question=question, document=document)
-    else:
-        return {"answer": "model not found", "score": "n/a"}
-
-
-
 def process_path(path):
     error = None
     if path:
@@ -230,6 +192,7 @@ def process_path(path):
         None,
     )
 
+
 def process_upload(file):
     if file:
         return process_path(file.name)
@@ -268,11 +231,19 @@ def normalize_bbox(box, width, height, padding=0.005):
     return [min_x * width, min_y * height, max_x * width, max_y * height]
 
 
-def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
-    prediction = run_pipeline(model, question, document, 3)
-    pages = [x.copy().convert("RGB") for x in document.preview]
+MODELS = {
+    TEXTRACT: run_textract,
+    LAYOUTLM: run_layoutlm,
+    DONUT: run_donut,
+    LILT: run_lilt,
+}
+
+
+def process_question(question, document, model=list(MODELS.keys())[0]):
+    prediction = MODELS[model](question=question, document=document)
     text_value = prediction["answer"]
     if "word_ids" in prediction:
+        pages = [x.copy().convert("RGB") for x in document.preview]
         image = pages[prediction["page"]]
         draw = ImageDraw.Draw(image, "RGBA")
         word_boxes = lift_word_boxes(document, prediction["page"])
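This MODELS dict replaces the deleted run_pipeline if/elif chain: every runner shares the (question, document) signature and returns at least "answer" and "score", so registering a new backend is a single entry. A sketch of how one would slot in (names hypothetical, not part of this commit):

def run_my_model(question, document):
    # Any callable with this signature and return shape plugs in.
    return {"answer": "stub", "score": "n/a"}

MODELS["My Model"] = run_my_model  # shows up automatically in the gr.Radio below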
@@ -397,7 +368,6 @@ gradio-app h2, .gradio-app h2 {
 """
 
 examples = [
-
     [
         "scenario-1.png",
         "What is the final consignee?",
@@ -416,7 +386,7 @@ examples = [
     ],
     [
         "scenario-4.png",
-        'What is the color?',
+        "What is the color?",
     ],
     [
         "scenario-5.png",
@@ -458,9 +428,7 @@ examples = [
 
 
 with gr.Blocks(css=CSS) as demo:
     gr.Markdown("# Document Query Engine")
-    gr.Markdown(
-        "Original version comes from DocQuery [here](https://huggingface.co/spaces/impira/docquery) (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=docquery_space))"
-    )
+    gr.Markdown("### Compare performance of different document layout models. If you have any suggestions [contact me](https://www.linkedin.com/in/vincent-claes-0b346337/)")
 
     document = gr.Variable()
     example_question = gr.Textbox(visible=False)
@@ -489,8 +457,8 @@ with gr.Blocks(css=CSS) as demo:
             max_lines=1,
         )
         model = gr.Radio(
-            choices=list(CHECKPOINTS.keys()),
-            value=list(CHECKPOINTS.keys())[0],
+            choices=list(MODELS.keys()),
+            value=list(MODELS.keys())[0],
            label="Model",
         )
 
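Since the Radio value is itself a MODELS key, the event wiring (outside these hunks) can pass it straight through to process_question. A sketch with hypothetical component names for the outputs:

question = gr.Textbox(label="Question")
question.submit(
    process_question,
    inputs=[question, document, model],
    outputs=[image, text_output],  # hypothetical outputs: annotated page + answer
)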