sbt-idp/cope2n-ai-fi/common/serve_model.py

94 lines
3.7 KiB
Python
Raw Normal View History

2023-11-30 11:22:16 +00:00
import cv2
from common.ocr import ocr_predict
from common.crop_location import crop_location
from Kie_AHung.prediction import infer_driving_license
from Kie_AHung_ID.prediction import infer_id_card
from common.post_processing_datetime import DatetimeCorrector
from transformers import (
LayoutXLMTokenizer,
LayoutLMv2FeatureExtractor,
LayoutXLMProcessor
)
# Module-level model setup: everything below runs once, at import time.

# Maximum sequence length, passed to the tokenizer as ``model_max_length``.
MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
# Tokenizer loaded from a relative path, so it is resolved against the
# working directory of the process that imports this module.
tokenizer = LayoutXLMTokenizer.from_pretrained(
"./Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
)
# apply_ocr=False — presumably because text lines are produced separately by
# ``ocr_predict`` inside ``predict``; TODO confirm.
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# Shared processor handed by ``predict`` to ``infer_driving_license`` and
# ``infer_id_card``.
processor = LayoutXLMProcessor(feature_extractor, tokenizer)
# Passed by ``predict`` to both inference functions.
# NOTE(review): looks like a cap on the number of words per inference
# window — confirm against the inference code.
max_n_words = 100
def _build_fields(result_dict):
    """Turn a KIE result mapping into the list of field dicts for one page.

    Entries whose ``"text"`` is the empty string are skipped. The text of the
    ``"Date Range"`` label is passed through ``DatetimeCorrector.correct``.
    """
    fields = []
    for kie_label in result_dict:
        entry = result_dict[kie_label]
        text = entry["text"]
        if text == "":
            continue
        if kie_label == "Date Range":
            text = DatetimeCorrector.correct(text)
        # The previous code did ``text.replace("", " ") if "" in text``.
        # ``"" in text`` is always true and replacing the empty string puts a
        # space between every character, corrupting every value.
        # NOTE(review): a placeholder character was probably lost from that
        # literal; if so, restore it here as ``text.replace(<char>, " ")``.
        fields.append({
            "label": kie_label,
            "value": text,
            "box": entry["bbox"],
            "confidence": 0.99,  # TODO: add confidence
        })
    return fields


def predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name):
    """Crop, OCR and extract key fields from every page of a request.

    Each page is cropped with ``crop_location``. When a crop is found it is
    written to disk, run through ``ocr_predict`` and then through the
    extraction model selected by ``infer_name``.

    Args:
        rq_id: Request identifier; used only to build the crop file path.
        sub_id: Subscription identifier; used only to build the file path.
        folder_name: Folder name; used only to build the file path.
        list_url: Iterable of page dicts. Each must provide ``'file_url'``
            (handed to ``crop_location``) plus ``'page_number'`` and
            ``'request_file_id'`` (both echoed back in the output).
        user_id: User identifier; used only to build the file path.
        infer_name: ``"driving_license"`` or ``"id_card"``.

    Returns:
        dict: ``{"model": {...}, "pages": [...]}`` with one entry per page in
        input order. Each entry has ``page_index``, ``path_image_croped``,
        ``request_file_id`` and ``fields``. ``path_image_croped`` and
        ``fields`` are ``None`` when ``crop_location`` returned ``None``.

    Raises:
        ValueError: If ``infer_name`` is not a supported document type.
    """
    # Fail fast: previously an unknown name only surfaced as a NameError
    # after the crop had been written and OCR had already run.
    supported_names = ("driving_license", "id_card")
    if infer_name not in supported_names:
        raise ValueError(
            "Unsupported infer_name {!r}; expected one of {}".format(
                infer_name, ", ".join(supported_names)
            )
        )

    results = {
        "model": {
            "name": infer_name,
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }

    # Both paths depend only on the request, so build them once.
    # NOTE(review): the crop is written under /users/... but reported under
    # /app/media/users/...; presumably the same volume seen from two mount
    # points — confirm. The path has no page component, so each cropped page
    # overwrites the previous one — confirm this is intended.
    relative_path = "/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(
        user_id, sub_id, folder_name, rq_id
    )
    write_path = relative_path
    path_image_croped = "/app/media" + relative_path

    compile_outputs = []
    for page in list_url:
        image_location = crop_location(page['file_url'])
        if image_location is None:
            # No crop found for this page: report it with empty results.
            compile_outputs.append({
                'page_index': page['page_number'],
                'path_image_croped': None,
                'request_file_id': page['request_file_id'],
                'fields': None,
            })
            continue

        cv2.imwrite(write_path, image_location)
        list_line = ocr_predict(image_location)

        # Post-processing modules are imported lazily, as in the original,
        # so only the one that is actually needed gets loaded.
        if infer_name == "driving_license":
            from common.post_processing_driver import create_kie_dict
            infer = infer_driving_license
        else:  # "id_card", guaranteed by the validation above
            from common.post_processing_id import create_kie_dict
            infer = infer_id_card

        _, _, _, list_words = infer(image_location, list_line, max_n_words, processor)
        result_dict = create_kie_dict(list_words)

        compile_outputs.append({
            'page_index': page['page_number'],
            'path_image_croped': str(path_image_croped),
            'request_file_id': page['request_file_id'],
            'fields': _build_fields(result_dict),
        })

    results['pages'] = compile_outputs
    return results