sbt-idp/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/main.py
2023-11-30 18:22:16 +07:00

148 lines
6.9 KiB
Python

import os
import glob
import cv2
import json
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from sdsvkvu.sources.kvu import KVUEngine
from sdsvkvu.sources.utils import export_kvu_outputs, export_sbt_outputs, draw_kvu_outputs
from sdsvkvu.utils.utils import create_dir, write_to_json, pdf2img
from sdsvkvu.utils.query.vat import export_kvu_for_VAT_invoice, merged_kvu_for_VAT_invoice_for_multi_pages
from sdsvkvu.utils.query.sbt import export_kvu_for_SDSAP, merged_kvu_for_SDSAP_for_multi_pages
from sdsvkvu.utils.query.vtb import export_kvu_for_vietin, merged_kvu_for_vietin_for_multi_pages
from sdsvkvu.utils.query.all import export_kvu_for_all, merged_kvu_for_all_for_multi_pages
from sdsvkvu.utils.query.manulife import export_kvu_for_manulife, merged_kvu_for_manulife_for_multi_pages
from sdsvkvu.utils.query.sbt_v2 import export_kvu_for_SBT, merged_kvu_for_SBT_for_multi_pages
def _str2bool(value) -> bool:
    """Parse a command-line boolean such as "true", "False", "yes", "0".

    argparse's ``type=bool`` calls ``bool("False")``, which is ``True`` because
    every non-empty string is truthy; this converter interprets the text
    instead. An empty string maps to False, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings; ``None`` (default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``img_dir``, ``save_dir``, ``doc_type``,
        ``export_img`` (a real bool) and ``kvu_params`` (a JSON string, "" by
        default).
    """
    parser = argparse.ArgumentParser(description='Main file')
    parser.add_argument('--img_dir', type=str, required=True,
                        help='path to input image/directory file')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='path to save directory')
    parser.add_argument('--doc_type', type=str, default="vat",
                        help='type of document')
    # `--export_img` alone means True; `--export_img False` now really means False.
    parser.add_argument('--export_img', type=_str2bool, nargs='?', const=True, default=False,
                        help='export image of output visualization')
    parser.add_argument('--kvu_params', type=str, required=False, default="")
    return parser.parse_args(argv)
def load_engine(kwargs) -> KVUEngine:
    """Construct a KVUEngine and report its settings on stdout.

    Args:
        kwargs: Either a dict of keyword arguments for ``KVUEngine``, or a JSON
            string encoding such a dict. A falsy non-dict value (e.g. "" or
            None) constructs ``KVUEngine()`` with no keyword arguments.

    Returns:
        The constructed engine.
    """
    print('[INFO] Loading Key-Value Understanding model ...')
    if isinstance(kwargs, dict):
        engine_kwargs = kwargs
    elif kwargs:
        engine_kwargs = json.loads(kwargs)
    else:
        engine_kwargs = {}

    kvu_engine = KVUEngine(**engine_kwargs)
    print("[INFO] Loaded model")
    print("[INFO] KVU engine settings: \n", kvu_engine._settings)
    return kvu_engine
def process_img(img_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str) -> dict:
    """Run KVU on a single image and write the post-processed result as JSON.

    Args:
        img_path: Path of the image passed to ``engine.predict``.
        save_dir: Directory (created if missing) that receives ``<fname>.json``.
        engine: Loaded KVU engine. ``engine._settings.mode`` selects the raw
            exporter: modes 0-3 use ``export_kvu_outputs``, mode 4 uses
            ``export_sbt_outputs``.
        export_all: When True, also write the raw outputs (JSON) and a
            visualization image into ``<save_dir>/kvu_results``.
        option: Document type selecting the post-processor: "vat", "sbt",
            "vtb", "manulife" or "sbt_v2"; any other value falls back to
            ``export_kvu_for_all``.

    Returns:
        The post-processed outputs, also written to ``<save_dir>/<fname>.json``.

    Raises:
        ValueError: if mode 4 is combined with an option other than "sbt_v2"
            (or "sbt_v2" with another mode), if the mode is not 0-4, or if the
            engine returns anything other than exactly one window.
    """
    mode = engine._settings.mode
    # Mode 4 and option "sbt_v2" are only valid together. Checked with an
    # explicit raise (not `assert`) so the validation survives `python -O`.
    if (mode == 4) != (option == "sbt_v2"):
        raise ValueError("[ERROR] Mode (4) has just supported option \"sbt_v2\"")
    # Fail before the expensive prediction; previously an unknown mode only
    # surfaced later as an UnboundLocalError on `raw_outputs`.
    if not (mode in range(4) or mode == 4):
        raise ValueError(f"Unsupported KVU mode: {mode}")

    print("=" * 5, os.path.basename(img_path))
    create_dir(save_dir)
    fname, img_ext = os.path.splitext(os.path.basename(img_path))
    out_ext = ".json"

    image, lbbox, lwords, pr_class_words, pr_relations = engine.predict(img_path)
    if len(lbbox) != 1:
        raise ValueError(
            f"Not support to predict each separated window: {len(lbbox)}"
        )
    # Exactly one window is guaranteed above, so index it directly.
    words, boxes, class_words, relations = lwords[0], lbbox[0], pr_class_words[0], pr_relations[0]
    class_names = engine._settings.class_names

    if mode in range(4):
        raw_outputs = export_kvu_outputs(words, boxes, class_words, relations, class_names)
    else:  # mode == 4, validated above
        raw_outputs = export_sbt_outputs(words, boxes, class_words, relations, class_names)

    if export_all:
        save_path = os.path.join(save_dir, 'kvu_results')
        create_dir(save_path)
        write_to_json(os.path.join(save_path, fname + out_ext), raw_outputs)
        # NOTE(review): cv2.imwrite expects BGR channel order; confirm that the
        # image returned by engine.predict is not RGB.
        visualization = draw_kvu_outputs(np.array(image), boxes, class_words, relations, class_names=class_names)
        cv2.imwrite(os.path.join(save_path, fname + img_ext), visualization)

    exporters = {
        "vat": export_kvu_for_VAT_invoice,
        "sbt": export_kvu_for_SDSAP,
        "vtb": export_kvu_for_vietin,
        "manulife": export_kvu_for_manulife,
        "sbt_v2": export_kvu_for_SBT,
    }
    outputs = exporters.get(option, export_kvu_for_all)(raw_outputs)
    write_to_json(os.path.join(save_dir, fname + out_ext), outputs)
    return outputs
def process_pdf(pdf_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, n_pages: int = -1) -> dict:
    """Run KVU on the pages of a PDF and merge the per-page results.

    Pages are rendered by ``pdf2img`` into
    ``<save_dir>/<parent-folder-name>_<fname>/``. Each page goes through
    ``process_img`` (which writes a per-page JSON into that folder), and the
    merged result is written to ``<save_dir>/<fname>.json``.

    Args:
        pdf_path: Path of the PDF file.
        save_dir: Output directory.
        engine: Loaded KVU engine, forwarded to ``process_img``.
        export_all: Forwarded to ``process_img``.
        option: Document type; selects both the per-page post-processor (in
            ``process_img``) and the multi-page merger here. Unknown values
            fall back to ``merged_kvu_for_all_for_multi_pages``.
        n_pages: Forwarded to ``pdf2img`` (presumably -1 means all pages;
            confirm against ``pdf2img``).

    Returns:
        The merged outputs.
    """
    out_ext = ".json"
    fname = os.path.splitext(os.path.basename(pdf_path))[0]
    # The page folder also carries the parent folder's name (presumably to keep
    # same-named PDFs from different folders apart).
    img_dirname = '_'.join([os.path.basename(os.path.dirname(pdf_path)), fname])
    img_save_dir = os.path.join(save_dir, img_dirname)
    create_dir(img_save_dir)
    list_img_files = pdf2img(pdf_path, img_save_dir, n_pages=n_pages, return_fname=True)

    # process_img already prints the "=====" banner for each page, so no extra
    # print here (it used to be printed twice per page).
    page_outputs = [
        process_img(img_path, img_save_dir, engine, export_all=export_all, option=option)
        for img_path in list_img_files
    ]

    mergers = {
        "vat": merged_kvu_for_VAT_invoice_for_multi_pages,
        "sbt": merged_kvu_for_SDSAP_for_multi_pages,
        "vtb": merged_kvu_for_vietin_for_multi_pages,
        "manulife": merged_kvu_for_manulife_for_multi_pages,
        "sbt_v2": merged_kvu_for_SBT_for_multi_pages,
    }
    outputs = mergers.get(option, merged_kvu_for_all_for_multi_pages)(page_outputs)
    write_to_json(os.path.join(save_dir, fname + out_ext), outputs)
    return outputs
def process_dir(dir_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, dir_level: int = 0) -> None:
    """Run KVU on every image and PDF found in ``dir_path``.

    Files are searched exactly ``dir_level`` sub-directories deep (0 means
    directly inside ``dir_path``). Extensions .jpg/.jpeg/.png/.pdf are matched
    case-insensitively. PDFs are handled by ``process_pdf`` and all other
    matches by ``process_img``; those functions write the results into
    ``save_dir``.
    """
    supported_exts = {".jpg", ".jpeg", ".png", ".pdf"}
    pattern = os.path.join(dir_path, f"{'*/' * dir_level}*")
    # A single glob plus a case-insensitive suffix filter. Globbing each
    # spelling separately ("*.jpg", "*.JPG") returns the same file twice on
    # case-insensitive filesystems and misses spellings such as ".PDF".
    # Sorting makes the processing order deterministic.
    list_images = sorted(
        path for path in glob.glob(pattern)
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in supported_exts
    )
    print('No. images:', len(list_images))
    for file_path in tqdm(list_images):
        if os.path.splitext(file_path)[1].lower() == ".pdf":
            process_pdf(file_path, save_dir, engine, export_all=export_all, option=option, n_pages=-1)
        else:
            process_img(file_path, save_dir, engine, export_all=export_all, option=option)
def Predictor_KVU(img: np.ndarray, save_dir: str, engine: KVUEngine, tmp_dir: str = "") -> dict:
    """Run the VAT-invoice KVU pipeline on an in-memory image.

    The image is written to disk with ``cv2.imwrite`` and then processed by
    ``process_img(..., export_all=False, option="vat")``. Because the file stem
    is a timestamp, the result is stored as ``<save_dir>/<YYYY-mm-dd HH-MM-SS>.json``.

    Args:
        img: Image array accepted by ``cv2.imwrite``. It was previously
            annotated ``str``, but cv2 requires an array.
        save_dir: Output directory passed to ``process_img``.
        engine: Loaded KVU engine.
        tmp_dir: Directory for the intermediate JPEG, which is kept afterwards.
            If empty (the default), a temporary directory is used and deleted
            after processing. This replaces a hard-coded developer-specific
            path that did not exist on other machines.

    Returns:
        The VAT post-processed outputs from ``process_img``.

    Raises:
        OSError: if the intermediate image cannot be written.
    """
    import tempfile

    curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    # The file stem becomes the output JSON name, so keep the timestamp naming.
    file_name = "{}.jpg".format(curr_datetime)

    def _run(directory: str) -> dict:
        image_path = os.path.join(directory, file_name)
        # cv2.imwrite reports failure by returning False; don't ignore it.
        if not cv2.imwrite(image_path, img):
            raise OSError(f"Failed to write intermediate image: {image_path}")
        return process_img(image_path, save_dir, engine, export_all=False, option="vat")

    if tmp_dir:
        os.makedirs(tmp_dir, exist_ok=True)
        return _run(tmp_dir)
    with tempfile.TemporaryDirectory() as scratch_dir:
        return _run(scratch_dir)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, build the engine once, then run
    # it over the files found in --img_dir.
    cli_args = get_args()
    kvu_engine = load_engine(cli_args.kvu_params)
    process_dir(
        cli_args.img_dir,
        cli_args.save_dir,
        kvu_engine,
        export_all=cli_args.export_img,
        option=cli_args.doc_type,
    )
    print('[INFO] Done')