sbt-idp/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py

114 lines
3.3 KiB
Python
Raw Normal View History

2023-11-30 11:22:16 +00:00
import cv2
2023-12-12 11:51:32 +00:00
import nltk
2023-11-30 11:22:16 +00:00
import urllib
import random
import numpy as np
from pathlib import Path
2024-06-06 07:08:10 +00:00
import urllib.parse
2023-12-05 05:59:06 +00:00
import uuid
2023-12-21 10:31:55 +00:00
from copy import deepcopy
2023-11-30 11:22:16 +00:00
import sys, os
# Put the directory three levels above this file on sys.path so the
# `modules`, `configs` and `utils` packages imported below can be resolved.
cur_dir = str(Path(__file__).parents[2])
sys.path.append(cur_dir)
# Let nltk also search `<current working directory>/nltk_data` for its data.
nltk.data.path.append(os.path.join(os.getcwd() + '/nltk_data'))
2023-11-30 11:22:16 +00:00
from modules.sdsvkvu import load_engine, process_img
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
2024-07-05 13:14:47 +00:00
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
2023-11-30 11:22:16 +00:00
2024-07-05 13:14:47 +00:00
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)
# The configs are passed as lazy %-style arguments, so the message needs a
# matching "%s" placeholder; without it the config is never rendered (it is
# silently dropped for a mapping, or triggers a logging error otherwise).
logger.info("OCR engine configfs: \n%s", ocr_cfg)
logger.info("KVU configfs: \n%s", kvu_cfg)
2023-11-30 11:22:16 +00:00
2023-12-12 11:51:32 +00:00
# The OCR settings are handed to the KVU config as plain configuration
# (an earlier variant built the OCR engine here and passed the engine itself).
kvu_cfg['ocr_configs'] = ocr_cfg

# "option" is taken out of the config for the duration of load_engine() and
# put back afterwards; it is also kept at module level because sbt_predict()
# forwards it to process_img().
# NOTE(review): presumably load_engine() does not accept an "option" key - confirm.
option = kvu_cfg.pop("option")
sbt_engine = load_engine(kvu_cfg)
kvu_cfg["option"] = option
2023-11-30 11:22:16 +00:00
2024-04-05 11:50:41 +00:00
def _file_name_from_url(image_url):
    """Choose the temporary file name for the image fetched from `image_url`.

    The `file_name` query parameter is used when present, reduced to its base
    name so that a crafted value such as `../../x.jpg` cannot point outside
    the temporary directory. When the parameter is missing, unusable or
    empty, a random `<uuid4>.jpg` name is returned instead.
    """
    try:
        parsed_url = urllib.parse.urlparse(image_url)
        query_params = urllib.parse.parse_qs(parsed_url.query)
        file_name = os.path.basename(query_params['file_name'][0])
    except Exception:
        # Deliberately best-effort: any failure here just means "use a random name".
        logger.info(f"[ERROR]: Error extracting file name from url: {image_url}")
        file_name = ""
    return file_name or f"{uuid.uuid4()}.jpg"


def sbt_predict(image_url, engine, metadata=None):
    """Download the image at `image_url` and run it through `process_img`.

    The image is decoded, written to a temporary file under `./tmp_results`,
    processed, and the temporary file is removed again even if processing
    fails.

    Args:
        image_url (str): URL of the image; an optional `file_name` query
            parameter names the temporary file.
        engine: engine object forwarded to `process_img`.
        metadata (dict | None): optional request metadata; its "subsidiary"
            entry is forwarded to `process_img` as `extra_params['sub']`.

    Returns:
        Whatever `process_img` returns.

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    # Imported locally: the module only imports `urllib` and `urllib.parse`,
    # which does not guarantee that `urllib.request` is loaded.
    from urllib.request import urlopen

    metadata = {} if metadata is None else metadata

    # Close the HTTP response as soon as the payload has been read.
    with urlopen(image_url) as response:
        arr = np.asarray(bytearray(response.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)
    if img is None:
        # Fail with a clear message instead of handing None to cv2.imwrite.
        raise ValueError(f"Could not decode an image from url: {image_url}")

    save_dir = "./tmp_results"
    os.makedirs(save_dir, exist_ok=True)
    tmp_image_path = os.path.join(save_dir, _file_name_from_url(image_url))
    try:
        cv2.imwrite(tmp_image_path, img)
        extra_params = {'sub': metadata.get("subsidiary", None)}  # example of 'AU'
        return process_img(img=tmp_image_path,
                           save_dir=save_dir,
                           engine=engine,
                           export_all=False,
                           option=option,
                           extra_params=extra_params)
    finally:
        # Always clean up, including when writing or processing raised.
        if os.path.exists(tmp_image_path):
            os.remove(tmp_image_path)
2024-04-05 11:50:41 +00:00
def predict(page_numb, image_url, metadata={}):
    """
    module predict function: run the SBT engine on one image and wrap the
    result in the response structure.
    Args:
        page_numb: page number, echoed as "page_number" and on every field
        image_url (str): image url, forwarded to sbt_predict
        metadata (dict): forwarded unchanged to sbt_predict
    Returns:
        dict: output of model, for example:
            {
                "document_type": "invoice",
                "document_class": " ",
                "page_number": 0,
                "fields": [
                    {
                        "label": "Invoice Number",
                        "value": "INV-12345",
                        "box": [0, 0, 0, 0],
                        "confidence": 0.98,
                        "page": 0
                    },
                    ...
                ]
            }
        "box" is always [0, 0, 0, 0] and "confidence" is a placeholder drawn
        with random.uniform(0.9, 1.0); neither comes from the model.
    """
    raw_fields = sbt_predict(image_url, engine=sbt_engine, metadata=metadata)
    # One entry per key/value pair returned by the engine.
    fields = [
        {
            "label": label,
            "value": value,
            "box": [0, 0, 0, 0],
            "confidence": random.uniform(0.9, 1.0),
            "page": page_numb,
        }
        for label, value in raw_fields.items()
    ]
    return {
        "document_type": "invoice",
        "document_class": " ",
        "page_number": page_numb,
        "fields": fields,
    }
if __name__ == "__main__":
    # Manual smoke run against a local sample image. sbt_predict() fetches the
    # image with urlopen(), which rejects a bare filesystem path ("unknown url
    # type"), so the sample is addressed through a file:// URL.
    image_url = "file:///root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
    output = predict(0, image_url)
    logger.info(output)