sbt-idp/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py
2023-12-21 17:31:55 +07:00

104 lines
2.7 KiB
Python
Executable File

import cv2
import nltk
import urllib
import random
import numpy as np
from pathlib import Path
import uuid
from copy import deepcopy
import sys, os
# Make the project root importable so the `modules.*` / `configs.*` imports
# below resolve. parents[2] is three directory levels above this file.
cur_dir = str(Path(__file__).parents[2])
sys.path.append(cur_dir)
# NLTK data is looked up relative to the *current working directory*, not to
# this file — NOTE(review): confirm the service is always started from the
# directory that contains `nltk_data`.
nltk.data.path.append(os.path.join(os.getcwd(), 'nltk_data'))
from modules.sdsvkvu import load_engine, process_img
from modules.ocr_engine import OcrEngine
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
def load_ocr_engine(opt) -> OcrEngine:
    """Construct an ``OcrEngine`` from a keyword-argument mapping.

    Args:
        opt: mapping whose items are forwarded as keyword arguments to
            ``OcrEngine``.

    Returns:
        The newly constructed ``OcrEngine``.
    """
    print("[INFO] Loading engine...")
    ocr = OcrEngine(**opt)
    print("[INFO] Engine loaded")
    return ocr
print("OCR engine configfs: \n", ocr_cfg)
print("KVU configfs: \n", kvu_cfg)
# No OcrEngine is built here; the OCR config is handed to the KVU config under
# 'ocr_configs' before load_engine() is called. NOTE(review): presumably
# load_engine builds its own OCR engine from it — confirm in modules.sdsvkvu.
kvu_cfg['ocr_configs'] = ocr_cfg
# 'option' is withheld from load_engine() but kept as a module global because
# sbt_predict() passes it to process_img() on every request.
option = kvu_cfg.pop("option")
try:
    sbt_engine = load_engine(kvu_cfg)
finally:
    # Put the key back so the shared config dict is left as we found it,
    # even if engine loading fails.
    kvu_cfg["option"] = option
def sbt_predict(image_url, engine):
    """Fetch an image, run the KVU engine on it and return the engine output.

    The image is fetched with ``urllib.request.urlopen`` (so ``http(s)://``
    and ``file://`` URLs work; bare filesystem paths do not). It is decoded
    with OpenCV and written to a uniquely named JPEG under ``./tmp_results``,
    because ``process_img`` takes a file path. The temporary file is always
    removed afterwards.

    Args:
        image_url (str): URL of the image to process.
        engine: engine object (as returned by ``load_engine``), passed
            through to ``process_img``.

    Returns:
        Whatever ``process_img`` returns; ``predict`` consumes it as a
        mapping of field label -> value.

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
        OSError: if the temporary image cannot be written.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly.
    import urllib.request

    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(image_url) as resp:
        arr = np.asarray(bytearray(resp.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)  # -1 == cv2.IMREAD_UNCHANGED
    if img is None:
        # imdecode signals failure by returning None; fail with a clear message
        # instead of an opaque assertion inside imwrite.
        raise ValueError(f"Could not decode image from {image_url!r}")

    save_dir = "./tmp_results"
    # cv2.imwrite does not create directories and only reports failure via
    # its return value, so make sure the directory exists.
    os.makedirs(save_dir, exist_ok=True)
    tmp_image_path = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
    if not cv2.imwrite(tmp_image_path, img):
        raise OSError(f"Could not write temporary image to {tmp_image_path!r}")
    try:
        outputs = process_img(img_path=tmp_image_path,
                              save_dir=save_dir,
                              engine=engine,
                              export_all=False,
                              option=option)
    finally:
        # Remove the temp file even when processing fails.
        os.remove(tmp_image_path)
    return outputs
def predict(page_numb, image_url):
    """Run SBT key-value extraction on one page image and format the result.

    Args:
        page_numb: page number; copied into ``page_number`` of the result
            and into ``page`` of every field.
        image_url (str): URL of the page image (fetched by ``sbt_predict``).

    Returns:
        dict: output of the model, shaped like::

            {
                "document_type": "invoice",
                "document_class": " ",
                "page_number": 0,
                "fields": [
                    {
                        "label": "Invoice Number",
                        "value": "INV-12345",
                        "box": [0, 0, 0, 0],
                        "confidence": 0.98,
                        "page": 0
                    },
                    ...
                ]
            }

        One field is emitted per key of the engine output, in its iteration
        order.
    """
    sbt_result = sbt_predict(image_url, engine=sbt_engine)
    print(sbt_result)
    # NOTE(review): `box` is a fixed placeholder and `confidence` is a random
    # value in [0.9, 1.0]; neither comes from the model, so downstream
    # consumers should not treat them as real scores/locations.
    fields = [
        {
            "label": label,
            "value": value,
            "box": [0, 0, 0, 0],
            "confidence": random.uniform(0.9, 1.0),
            "page": page_numb,
        }
        for label, value in sbt_result.items()
    ]
    return {
        "document_type": "invoice",
        "document_class": " ",
        "page_number": page_numb,
        "fields": fields,
    }
if __name__ == "__main__":
    # Usage: python prediction_sbt.py [IMAGE_PATH_OR_URL]
    # Falls back to the original sample image when no argument is given.
    image_url = (sys.argv[1] if len(sys.argv) > 1
                 else "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg")
    # sbt_predict() uses urllib.request.urlopen, which rejects bare filesystem
    # paths ("unknown url type"); turn local paths into file:// URIs.
    if "://" not in image_url:
        image_url = Path(image_url).resolve().as_uri()
    output = predict(0, image_url)
    print(output)