sbt-idp/cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py
2023-11-30 18:22:16 +07:00

83 lines
1.9 KiB
Python
Executable File

from common.utils_invoice.run_ocr import ocr_predict
import os
from Kie_Hoanglv.prediction2 import KIEInvoiceInfer
from configs.config_invoice.layoutxlm_base_invoice import *
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2
# Module-level KIE inference object: constructed once, when this module is
# imported, and reused by every predict() call in this file.
# NOTE(review): none of the UPPER_CASE settings below are defined in this
# file; they presumably come from the star import of
# configs.config_invoice.layoutxlm_base_invoice -- confirm.
model = KIEInvoiceInfer(
    weight_dir=TRAINED_DIR,
    tokenizer_dir=TOKENIZER_DIR,
    max_seq_len=MAX_SEQ_LENGTH,
    classes=KIE_LABELS,
    device=DEVICE,
    outdir_visualize=VISUALIZE_DIR,
)
def format_result(result):
    """Flatten a raw KIE prediction into a list of per-field dictionaries.

    ``result[0]`` is iterated to obtain the field keys and indexed by each key
    to obtain the field value. ``result[1]`` is indexed by the position of the
    key in that iteration to obtain the matching box.

    Example of the returned structure (keys and values are illustrative)::

        [
            {
                "key": "name",
                "value": "Nguyen Hoang Hiep",
                "true_box": [373, 113, 700, 420],
            },
            {
                "key": "address",
                "value": "...",
                "true_box": [10, 10, 20, 20],
            },
        ]
    """
    fields = result[0]
    # result[1] is only touched while iterating, so an empty ``fields`` never
    # requires a second element in ``result``.
    return [
        {
            "key": field_name,
            "value": fields[field_name],
            "true_box": result[1][position],
        }
        for position, field_name in enumerate(fields)
    ]
def predict(image_url, *, timeout=30):
    """Download an image, run OCR and KIE inference on it, and format the result.

    Args:
        image_url: URL of the image, fetched with an HTTP GET.
        timeout: Seconds handed to ``requests.get`` so that a stalled server
            cannot block the caller forever. Pass ``None`` to wait
            indefinitely, which is what happened before this parameter
            existed.

    Returns:
        The list built by ``format_result`` from the model output: one dict
        per field with ``key``, ``value`` and ``true_box`` entries.

    Raises:
        requests.exceptions.Timeout: If the server does not answer within
            ``timeout``.
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx
            status.
    """
    # exist_ok=True already makes these calls no-ops for existing directories;
    # a separate os.path.exists() check only adds a race window.
    os.makedirs(PRED_DIR, exist_ok=True)
    os.makedirs(VISUALIZE_DIR, exist_ok=True)

    response = requests.get(image_url, timeout=timeout)
    # Surface HTTP failures here instead of handing an error page to PIL.
    response.raise_for_status()

    image = Image.open(BytesIO(response.content))
    # The RGB->BGR conversion below needs a 3-channel RGB array. Grayscale and
    # palette images would otherwise yield a 2-D array (which cv2 rejects) and
    # CMYK would be misread, so normalise every other mode first. Images that
    # are already RGB are passed through untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")
    cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    bboxes, texts = ocr_predict(cv_image)

    # NOTE(review): the search string is empty as written, so this inserts a
    # space around every character of every text. That looks like a
    # placeholder/marker character was lost from the literal -- confirm the
    # intended character against the OCR output and restore it here.
    # The literal is kept exactly as found. str.replace() already returns the
    # text unchanged when the marker is absent, so no membership test is
    # needed.
    texts_replaced = [text.replace("", " ") for text in texts]

    inputs = model.prepare_kie_inputs(image, ocr_info=[bboxes, texts_replaced])
    return format_result(model(inputs))