sbt-idp/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py

106 lines
3.1 KiB
Python
Executable File

import random
import sys, os
import urllib
import urllib.parse
import urllib.request
import uuid
from copy import deepcopy
from pathlib import Path

import cv2
import nltk
import numpy as np
# Make the project root (two directories above this file) importable so the
# `modules` and `configs` packages imported below can be resolved.
cur_dir = str(Path(__file__).parents[2])
sys.path.append(cur_dir)
# Let nltk also search ./nltk_data under the current working directory.
nltk.data.path.append(os.path.join(os.getcwd(), 'nltk_data'))
from modules.sdsvkvu import load_engine, process_img
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
print("OCR engine configfs: \n", ocr_cfg)
print("KVU configfs: \n", kvu_cfg)
kvu_cfg['ocr_configs'] = ocr_cfg
# `option` is not passed to load_engine: take it out for the call and put it
# back afterwards -- even if loading fails -- because kvu_cfg is the object
# owned by configs.sdsap_sbt and is shared with every other importer.
# The module-level `option` is what sbt_predict forwards to process_img.
option = kvu_cfg.pop("option")
try:
    sbt_engine = load_engine(kvu_cfg)
finally:
    kvu_cfg["option"] = option
def sbt_predict(image_url, engine, metadata=None):
    """Download an image, run the SBT KVU pipeline on it and return its output.

    The image is fetched with ``urllib.request.urlopen``, decoded with
    ``cv2.imdecode``, written to a temporary file under ``./tmp_results`` and
    passed to ``process_img`` together with the module-level ``option``. The
    temporary file is removed afterwards, also when ``process_img`` raises.

    Args:
        image_url (str): URL of the image. If its query string has a
            ``file_name`` parameter, its final path component is used as the
            temporary file name; otherwise a random ``<uuid>.jpg`` is used.
        engine: engine object returned by ``load_engine``.
        metadata (dict, optional): request metadata. Only ``"subsidiary"`` is
            read; it is forwarded as ``extra_params["sub"]``. Defaults to ``{}``.

    Returns:
        The value returned by ``process_img``, unchanged.
    """
    if metadata is None:
        metadata = {}

    # NOTE(review): urlopen also accepts file:// (and other) schemes; if
    # image_url can come from untrusted clients, validate scheme/host first.
    with urllib.request.urlopen(image_url) as response:
        arr = np.asarray(bytearray(response.read()), dtype=np.uint8)
    # -1 == cv2.IMREAD_UNCHANGED. imdecode returns None for bytes it cannot
    # decode; the imwrite below is then expected to fail.
    img = cv2.imdecode(arr, -1)

    save_dir = "./tmp_results"
    try:
        parsed_url = urllib.parse.urlparse(image_url)
        query_params = urllib.parse.parse_qs(parsed_url.query)
        file_name = query_params['file_name'][0]
    except Exception:
        # Deliberate best effort: any URL without a usable file_name falls
        # back to a random name instead of failing the request.
        print(f"[ERROR]: Error extracting file name from url: {image_url}")
        file_name = f"{uuid.uuid4()}.jpg"
    # file_name comes from the URL, so keep only its final component; this
    # stops values like "../../x.jpg" or "/abs/x.jpg" from escaping save_dir.
    file_name = os.path.basename(file_name)
    if file_name in ("", ".", ".."):
        file_name = f"{uuid.uuid4()}.jpg"

    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): concurrent requests carrying the same file_name share this
    # path and can overwrite/delete each other's temp file.
    tmp_image_path = os.path.join(save_dir, file_name)
    cv2.imwrite(tmp_image_path, img)
    extra_params = {'sub': metadata.get("subsidiary", None)}  # example of 'AU'
    try:
        outputs = process_img(img=tmp_image_path,
                              save_dir=save_dir,
                              engine=engine,
                              export_all=False,
                              option=option,
                              extra_params=extra_params)
    finally:
        # Always clean up the temporary copy, even when inference fails.
        try:
            os.remove(tmp_image_path)
        except FileNotFoundError:
            pass
    return outputs
def predict(page_numb, image_url, metadata=None):
    """Run SBT extraction on one page image and format the result.

    Args:
        page_numb (int): page number, copied into the output and every field.
        image_url (str): URL of the page image (forwarded to ``sbt_predict``).
        metadata (dict, optional): request metadata forwarded to
            ``sbt_predict``. Defaults to an empty dict.

    Returns:
        dict: for example::

            {
                "document_type": "invoice",
                "document_class": " ",
                "page_number": 0,
                "fields": [
                    {
                        "label": "<key from the sbt_predict result>",
                        "value": "<value for that key>",
                        "box": [0, 0, 0, 0],
                        "confidence": 0.97,
                        "page": 0,
                    },
                    ...
                ],
            }

        There is one field per key of the ``sbt_predict`` result, in its
        iteration order. ``box`` is always ``[0, 0, 0, 0]`` and ``confidence``
        is drawn from ``random.uniform(0.9, 1.0)``; neither is taken from the
        ``sbt_predict`` result.
    """
    if metadata is None:
        metadata = {}
    sbt_result = sbt_predict(image_url, engine=sbt_engine, metadata=metadata)
    fields = [
        {
            "label": label,
            "value": value,
            "box": [0, 0, 0, 0],
            # Placeholder value, not derived from sbt_result.
            "confidence": random.uniform(0.9, 1.0),
            "page": page_numb,
        }
        for label, value in sbt_result.items()
    ]
    return {
        "document_type": "invoice",
        "document_class": " ",
        "page_number": page_numb,
        "fields": fields,
    }
if __name__ == "__main__":
    # Manual smoke test. urlopen() needs a URL, not a bare filesystem path
    # (it raises "unknown url type" otherwise), so local paths are converted
    # to a file:// URI. An optional first command-line argument overrides the
    # default sample image.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
    )
    if "://" in image_path:
        image_url = image_path
    else:
        image_url = Path(image_path).resolve().as_uri()
    output = predict(0, image_url)
    print(output)