148 lines
6.9 KiB
Python
148 lines
6.9 KiB
Python
import os
|
|
import glob
|
|
import cv2
|
|
import json
|
|
import argparse
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
from datetime import datetime
|
|
from sdsvkvu.sources.kvu import KVUEngine
|
|
from sdsvkvu.sources.utils import export_kvu_outputs, export_sbt_outputs, draw_kvu_outputs
|
|
from sdsvkvu.utils.utils import create_dir, write_to_json, pdf2img
|
|
from sdsvkvu.utils.query.vat import export_kvu_for_VAT_invoice, merged_kvu_for_VAT_invoice_for_multi_pages
|
|
from sdsvkvu.utils.query.sbt import export_kvu_for_SDSAP, merged_kvu_for_SDSAP_for_multi_pages
|
|
from sdsvkvu.utils.query.vtb import export_kvu_for_vietin, merged_kvu_for_vietin_for_multi_pages
|
|
from sdsvkvu.utils.query.all import export_kvu_for_all, merged_kvu_for_all_for_multi_pages
|
|
from sdsvkvu.utils.query.manulife import export_kvu_for_manulife, merged_kvu_for_manulife_for_multi_pages
|
|
from sdsvkvu.utils.query.sbt_v2 import export_kvu_for_SBT, merged_kvu_for_SBT_for_multi_pages
|
|
|
|
|
|
def _str2bool(value: str) -> bool:
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which makes every
    non-empty string -- including "False" -- True, so an explicit parser is
    required for a boolean option that takes a value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching the previous bool('') behaviour.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``img_dir``, ``save_dir``, ``doc_type``,
        ``export_img`` (a real bool) and ``kvu_params``.
    """
    parser = argparse.ArgumentParser(description='Main file')
    parser.add_argument('--img_dir', type=str, required=True,
                        help='path to input image/directory file')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='path to save directory')
    parser.add_argument('--doc_type', type=str, default="vat",
                        help='type of document')
    # nargs='?' + const=True: a bare "--export_img" means True, while
    # "--export_img True/False" keeps working and is now parsed correctly.
    parser.add_argument('--export_img', type=_str2bool, nargs='?', const=True,
                        default=False,
                        help='export image of output visualization')
    parser.add_argument('--kvu_params', type=str, required=False, default="",
                        help='JSON object of keyword arguments passed to KVUEngine')
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_engine(kwargs) -> KVUEngine:
    """Construct a KVUEngine and log its settings.

    Args:
        kwargs: Either a dict of KVUEngine keyword arguments, or a JSON string
            that is decoded with ``json.loads`` and unpacked into KVUEngine
            (presumably a JSON object). Any falsy non-dict value (e.g. ``""``
            or ``None``) means "no overrides".

    Returns:
        The constructed KVUEngine instance.
    """
    print('[INFO] Loading Key-Value Understanding model ...')

    if isinstance(kwargs, dict):
        engine_kwargs = kwargs
    elif kwargs:
        engine_kwargs = json.loads(kwargs)
    else:
        engine_kwargs = {}

    kvu_engine = KVUEngine(**engine_kwargs)
    print("[INFO] Loaded model")
    print("[INFO] KVU engine settings: \n", kvu_engine._settings)
    return kvu_engine
|
|
|
|
|
|
def process_img(img_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str) -> dict:
    """Run KVU on a single image and write the post-processed result as JSON.

    Args:
        img_path: Path to the input image.
        save_dir: Output directory (created via ``create_dir``). The final
            result is written to ``<save_dir>/<image stem>.json``.
        engine: Loaded KVU engine; its ``_settings.mode`` selects the raw
            exporter (modes 0-3 -> ``export_kvu_outputs``, mode 4 ->
            ``export_sbt_outputs``).
        export_all: When True, also write the raw KVU outputs (JSON) and a
            visualization image into ``<save_dir>/kvu_results``.
        option: Post-processing flavour: "vat", "sbt", "vtb", "manulife",
            "sbt_v2"; any other value uses the generic ``export_kvu_for_all``.

    Returns:
        The post-processed outputs of the selected exporter.

    Raises:
        ValueError: if engine mode 4 and option "sbt_v2" are not used
            together, if the engine mode is not 0-4, or if the engine returns
            a number of windows other than one.
    """
    mode = engine._settings.mode

    # Mode 4 and "sbt_v2" must go together. An explicit raise (instead of
    # `assert`) keeps the check active when Python runs with -O.
    if (mode == 4) != (option == "sbt_v2"):
        raise ValueError("[ERROR] Mode (4) has just supported option \"sbt_v2\"")

    # Resolve the raw exporter before the (expensive) prediction; an unknown
    # mode previously left `raw_outputs` unbound and crashed only afterwards.
    if mode in range(4):
        export_raw = export_kvu_outputs
    elif mode == 4:
        export_raw = export_sbt_outputs
    else:
        raise ValueError(f"[ERROR] Unsupported KVU mode: {mode}")

    print("=" * 5, os.path.basename(img_path))
    create_dir(save_dir)
    fname, img_ext = os.path.splitext(os.path.basename(img_path))
    out_ext = ".json"

    image, lbbox, lwords, pr_class_words, pr_relations = engine.predict(img_path)

    if len(lbbox) != 1:
        raise ValueError(
            f"Not support to predict each separated window: {len(lbbox)}"
        )

    # Exactly one window is guaranteed above, so index 0 is the only prediction.
    class_names = engine._settings.class_names
    raw_outputs = export_raw(lwords[0], lbbox[0], pr_class_words[0], pr_relations[0], class_names)

    if export_all:
        save_path = os.path.join(save_dir, 'kvu_results')
        create_dir(save_path)
        write_to_json(os.path.join(save_path, fname + out_ext), raw_outputs)
        # np.array converts whatever engine.predict returned (presumably a PIL
        # image or an array -- confirm) into an array for drawing.
        vis_image = draw_kvu_outputs(np.array(image), lbbox[0], pr_class_words[0], pr_relations[0],
                                     class_names=class_names)
        vis_path = os.path.join(save_path, fname + img_ext)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(vis_path, vis_image):
            print(f"[WARN] Failed to write visualization image: {vis_path}")

    exporters = {
        "vat": export_kvu_for_VAT_invoice,
        "sbt": export_kvu_for_SDSAP,
        "vtb": export_kvu_for_vietin,
        "manulife": export_kvu_for_manulife,
        "sbt_v2": export_kvu_for_SBT,
    }
    outputs = exporters.get(option, export_kvu_for_all)(raw_outputs)

    write_to_json(os.path.join(save_dir, fname + out_ext), outputs)
    return outputs
|
|
|
|
|
|
def process_pdf(pdf_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, n_pages: int = -1) -> dict:
    """Run KVU on every page of a PDF and merge the per-page results.

    The PDF is rasterised with ``pdf2img`` into
    ``<save_dir>/<parent folder name>_<pdf stem>/``; each page image is then
    run through ``process_img`` (which also writes its per-page JSON into that
    folder). The merged result is written to ``<save_dir>/<pdf stem>.json``.

    Args:
        pdf_path: Path to the input PDF.
        save_dir: Root output directory.
        engine: Loaded KVU engine.
        export_all: Forwarded to ``process_img`` for every page.
        option: Document type; selects the multi-page merger ("vat", "sbt",
            "vtb", "manulife", "sbt_v2"; anything else uses the generic one).
        n_pages: Forwarded to ``pdf2img`` (the default -1 presumably means
            "all pages" -- confirm against pdf2img).

    Returns:
        The merged outputs produced by the selected merger.
    """
    out_ext = ".json"
    fname = os.path.splitext(os.path.basename(pdf_path))[0]
    # Prefixing with the parent folder name presumably keeps same-named PDFs
    # from different folders from sharing one page-image directory.
    img_dirname = '_'.join([os.path.basename(os.path.dirname(pdf_path)), fname])
    img_save_dir = os.path.join(save_dir, img_dirname)
    create_dir(img_save_dir)

    list_img_files = pdf2img(pdf_path, img_save_dir, n_pages=n_pages, return_fname=True)

    # process_img already logs each page's file name, so no extra print here.
    page_outputs = [
        process_img(img_path, img_save_dir, engine, export_all=export_all, option=option)
        for img_path in list_img_files
    ]

    mergers = {
        "vat": merged_kvu_for_VAT_invoice_for_multi_pages,
        "sbt": merged_kvu_for_SDSAP_for_multi_pages,
        "vtb": merged_kvu_for_vietin_for_multi_pages,
        "manulife": merged_kvu_for_manulife_for_multi_pages,
        "sbt_v2": merged_kvu_for_SBT_for_multi_pages,
    }
    outputs = mergers.get(option, merged_kvu_for_all_for_multi_pages)(page_outputs)

    write_to_json(os.path.join(save_dir, fname + out_ext), outputs)
    return outputs
|
|
|
|
|
|
def process_dir(dir_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, dir_level: int = 0) -> None:
    """Run KVU on every supported image/PDF under ``dir_path``.

    Files are looked up exactly ``dir_level`` folder levels below
    ``dir_path`` (0 = directly inside it). Extensions (.jpg, .jpeg, .png,
    .pdf) are matched case-insensitively, and every file is visited once --
    previously ``*.JPG`` and ``*.jpg`` both matched the same files on
    case-insensitive filesystems. If ``dir_path`` is itself a supported file,
    only that file is processed.

    Args:
        dir_path: Input directory (or a single image/PDF file).
        save_dir: Output directory passed to ``process_img``/``process_pdf``.
        engine: Loaded KVU engine.
        export_all: Forwarded to the per-file processors.
        option: Document type forwarded to the per-file processors.
        dir_level: Number of sub-folder levels between ``dir_path`` and the files.
    """
    supported_exts = {'.jpg', '.jpeg', '.png', '.pdf'}

    def _is_supported(path):
        # Only regular files whose extension (any case) is supported.
        return os.path.isfile(path) and os.path.splitext(path)[1].lower() in supported_exts

    if os.path.isfile(dir_path):
        list_images = [dir_path] if _is_supported(dir_path) else []
    else:
        # One glob over every entry yields each path once; filter by extension afterwards.
        pattern = os.path.join(dir_path, '*/' * dir_level + '*')
        list_images = sorted(p for p in glob.glob(pattern) if _is_supported(p))
    print('No. images:', len(list_images))

    for file_path in tqdm(list_images):
        if os.path.splitext(file_path)[1].lower() == ".pdf":
            process_pdf(file_path, save_dir, engine, export_all=export_all, option=option, n_pages=-1)
        else:
            process_img(file_path, save_dir, engine, export_all=export_all, option=option)
|
|
|
|
|
|
def Predictor_KVU(img: np.ndarray, save_dir: str, engine: KVUEngine, tmp_dir=None) -> dict:
    """Run VAT-invoice KVU on an in-memory image.

    The image is first written to a timestamped JPEG in ``tmp_dir`` (the
    engine's ``process_img`` pipeline works from a file path), then processed
    with ``option="vat"`` and ``export_all=False``. The temporary JPEG is left
    on disk, as before.

    Args:
        img: Image array handed to ``cv2.imwrite`` (presumably BGR as OpenCV
            expects -- confirm with callers).
        save_dir: Output directory for the JSON result.
        engine: Loaded KVU engine.
        tmp_dir: Directory for the temporary JPEG. ``None`` (default) uses a
            ``sdsvkvu_tmp_image`` folder inside the system temp directory,
            replacing the previous machine-specific hard-coded path.

    Returns:
        The VAT-invoice outputs from ``process_img``.

    Raises:
        OSError: if the temporary image cannot be written.
    """
    import tempfile  # only needed here to locate the system temp directory

    if tmp_dir is None:
        tmp_dir = os.path.join(tempfile.gettempdir(), "sdsvkvu_tmp_image")
    os.makedirs(tmp_dir, exist_ok=True)

    # Microseconds stop two calls within the same second from overwriting each
    # other's temporary image (and output JSON, which is named after it).
    curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S-%f')
    image_path = os.path.join(tmp_dir, "{}.jpg".format(curr_datetime))

    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(image_path, img):
        raise OSError(f"Failed to write temporary image: {image_path}")

    vat_outputs = process_img(image_path, save_dir, engine, export_all=False, option="vat")
    return vat_outputs
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse options, build the engine once, then walk the
    # input directory. For a single input, process_img / process_pdf can be
    # called here instead of process_dir.
    cli_args = get_args()
    kvu_engine = load_engine(cli_args.kvu_params)

    process_dir(
        cli_args.img_dir,
        cli_args.save_dir,
        kvu_engine,
        export_all=cli_args.export_img,
        option=cli_args.doc_type,
    )
    print('[INFO] Done')
|