94 lines
3.7 KiB
Python
Executable File
94 lines
3.7 KiB
Python
Executable File
import cv2
|
|
from common.ocr import ocr_predict
|
|
from common.crop_location import crop_location
|
|
from Kie_AHung.prediction import infer_driving_license
|
|
from Kie_AHung_ID.prediction import infer_id_card
|
|
from common.post_processing_datetime import DatetimeCorrector
|
|
from transformers import (
|
|
LayoutXLMTokenizer,
|
|
LayoutLMv2FeatureExtractor,
|
|
LayoutXLMProcessor
|
|
)
|
|
|
|
# Maximum token sequence length handed to the LayoutXLM tokenizer.
# TODO: stop hard-coding this; derive it from the model/config instead.
MAX_SEQ_LENGTH = 512

# The tokenizer is loaded from a "./"-relative path, which is resolved against
# the current working directory, so the process must be started from the
# directory that contains ``Kie_AHung``.
_TOKENIZER_DIR = "./Kie_AHung/model/pretrained/layoutxlm-base/tokenizer"
tokenizer = LayoutXLMTokenizer.from_pretrained(_TOKENIZER_DIR, model_max_length=MAX_SEQ_LENGTH)

# OCR is performed separately (``ocr_predict`` in ``predict``), so the feature
# extractor is told not to run its own OCR.
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
processor = LayoutXLMProcessor(feature_extractor, tokenizer)

# Forwarded as the ``max_n_words`` argument of the KIE inference functions.
max_n_words = 100
|
|
|
|
def _resolve_kie_pipeline(infer_name):
    """Return the ``(inference_fn, create_kie_dict)`` pair for ``infer_name``.

    Raises:
        ValueError: If ``infer_name`` is not ``"driving_license"`` or ``"id_card"``.
    """
    # Post-processors are imported lazily, as in the original code, so only the
    # one actually needed is loaded.
    if infer_name == "driving_license":
        from common.post_processing_driver import create_kie_dict
        return infer_driving_license, create_kie_dict
    if infer_name == "id_card":
        from common.post_processing_id import create_kie_dict
        return infer_id_card, create_kie_dict
    raise ValueError(
        "Unsupported infer_name {!r}; expected 'driving_license' or 'id_card'".format(infer_name)
    )


def _extract_fields(result_dict):
    """Turn a KIE ``{label: {"text": ..., "bbox": ...}}`` dict into output fields.

    Labels whose raw text is empty are skipped. The "Date Range" label is
    passed through ``DatetimeCorrector.correct``. Every "✪" placeholder in the
    text is replaced with a space.
    """
    fields = []
    for label, item in result_dict.items():
        text = item["text"]
        if text == "":
            continue
        if label == "Date Range":
            text = DatetimeCorrector.correct(text)
        fields.append({
            "label": label,
            "value": text.replace("✪", " "),
            "box": item["bbox"],
            "confidence": 0.99,  # TODO: report a real confidence score
        })
    return fields


def predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name):
    """Run document cropping, OCR and key-information extraction on each page.

    For every page, ``crop_location`` locates the document. If a document is
    found, the crop is saved to disk, OCR'd with ``ocr_predict`` and then
    passed to the KIE model selected by ``infer_name``.

    Args:
        rq_id: Request id; last directory component of the crop path.
        sub_id: Subscription id; used in the crop path.
        folder_name: Request folder name; used in the crop path.
        list_url (list[dict]): Pages to process. Each dict must provide
            ``'file_url'``, ``'page_number'`` and ``'request_file_id'``.
        user_id: User id; used in the crop path.
        infer_name (str): KIE model to run: ``"driving_license"`` or ``"id_card"``.

    Returns:
        dict: ``{"model": {...}, "pages": [...]}`` with one entry in ``"pages"``
        per input page, in input order. A page where no document was located
        has ``'path_image_croped'`` and ``'fields'`` set to ``None``.

    Raises:
        ValueError: If ``infer_name`` is unsupported and some page yields a
            crop. The error is raised before anything is written for that page.
    """
    results = {
        "model": {
            "name": infer_name,
            "confidence": 1.0,
            # NOTE(review): hard-coded even for ID cards / driving licenses.
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }

    # NOTE(review): the crop path depends only on the request, not the page, so
    # every page of a multi-page request overwrites the same file and reports the
    # same path. The file is also written under "/users/..." but reported as
    # "/app/media/users/..."; confirm both refer to the same mount.
    crop_subpath = "users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(
        user_id, sub_id, folder_name, rq_id
    )

    compile_outputs = []
    for page in list_url:
        image_location = crop_location(page['file_url'])
        if image_location is None:
            compile_outputs.append({
                'page_index': page['page_number'],
                'path_image_croped': None,
                'request_file_id': page['request_file_id'],
                'fields': None,
            })
            continue

        # Validate the model choice before writing files or running OCR.
        infer, create_kie_dict = _resolve_kie_pipeline(infer_name)

        path_image_croped = "/app/media/" + crop_subpath
        # NOTE(review): cv2.imwrite returns False on failure instead of raising;
        # the result is not checked here.
        cv2.imwrite("/" + crop_subpath, image_location)

        list_line = ocr_predict(image_location)
        _, _, _, list_words = infer(image_location, list_line, max_n_words, processor)
        result_dict = create_kie_dict(list_words)

        compile_outputs.append({
            'page_index': page['page_number'],
            'path_image_croped': path_image_croped,
            'request_file_id': page['request_file_id'],
            'fields': _extract_fields(result_dict),
        })

    results['pages'] = compile_outputs
    return results
|