134 lines
6.5 KiB
Python
Executable File
134 lines
6.5 KiB
Python
Executable File
import os
|
|
import glob
|
|
import cv2
|
|
import argparse
|
|
from tqdm import tqdm
|
|
import urllib
|
|
import numpy as np
|
|
import imagesize
|
|
# from omegaconf import OmegaConf
|
|
import sys
|
|
cur_dir = os.path.dirname(__file__)
|
|
sys.path.append(cur_dir)
|
|
# sys.path.append('/cope2n-ai-fi/Kie_Invoice_AP/AnyKey_Value/')
|
|
from predictor import KVUPredictor
|
|
from preprocess import KVUProcess, DocumentKVUProcess
|
|
from utils.utils import create_dir, visualize, get_colormap, export_kvu_outputs, export_kvu_for_manulife
|
|
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the KVU inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``img_dir``, ``save_dir``,
        ``exp_dir``, ``export_img``, ``mode`` and ``dir_level``.
    """
    parser = argparse.ArgumentParser(description='Main file')
    parser.add_argument(
        '--img_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
        type=str,
        help='Input image directory',
    )
    parser.add_argument(
        '--save_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
        type=str,
        help='Save directory',
    )
    parser.add_argument(
        '--exp_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900',
        type=str,
        help='Checkpoint and config of model',
    )
    parser.add_argument('--export_img', default=0, type=int,
                        help='Save visualize on image')
    parser.add_argument('--mode', default=3, type=int,
                        help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'")
    parser.add_argument('--dir_level', default=0, type=int,
                        help='Number of subfolders contains image')

    return parser.parse_args(argv)
|
|
|
|
|
|
def load_engine(exp_dir: str, class_names: list, mode: int) -> "tuple[KVUPredictor, DocumentKVUProcess]":
    """Build the predictor and its matching input processor.

    Args:
        exp_dir: Experiment directory. It must contain at least one ``*.yaml``
            config file; the checkpoint path is taken to be
            ``<exp_dir>/checkpoints/best_model.pth``.
        class_names: Class labels handed to both the predictor and the
            processor.
        mode: Prediction mode handed to both the predictor and the processor.

    Returns:
        A ``(predictor, processor)`` tuple.

    Raises:
        FileNotFoundError: If ``exp_dir`` contains no ``*.yaml`` file.
    """
    config_files = glob.glob(f'{exp_dir}/*.yaml')
    if not config_files:
        # Fail with an actionable message instead of a bare IndexError.
        raise FileNotFoundError(
            f'No *.yaml config file found in experiment directory: {exp_dir}'
        )

    configs = {
        'cfg': config_files[0],
        'ckpt': f'{exp_dir}/checkpoints/best_model.pth',
    }
    # NOTE(review): meaning of 512 is defined by KVUPredictor - confirm there.
    dummy_idx = 512
    predictor = KVUPredictor(configs, class_names, dummy_idx, mode)

    # NOTE(review): the document-level processor is built for every `mode`
    # value; confirm this is intended for modes other than 'document'.
    processor = DocumentKVUProcess(
        predictor.net.tokenizer,
        predictor.net.feature_extractor,
        predictor.backbone_type,
        class_names,
        predictor.max_window_count,
        predictor.slice_interval,
        predictor.window_size,
        run_ocr=1,
        mode=mode,
    )
    return predictor, processor
|
|
|
|
def revert_box(box, width, height, scale=1000):
    """Map a normalised box back to pixel coordinates.

    Args:
        box: Indexable with four numbers. Items 0 and 2 are scaled by
            ``width``; items 1 and 3 are scaled by ``height``.
        width: Image width in pixels.
        height: Image height in pixels.
        scale: Upper bound of the normalised coordinate range. Defaults to
            1000, the value that was previously hard-coded.

    Returns:
        A list of four ints; each value is truncated towards zero by ``int``.
    """
    return [
        int((box[0] / scale) * width),
        int((box[1] / scale) * height),
        int((box[2] / scale) * width),
        int((box[3] / scale) * height),
    ]
|
|
|
|
def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor):
    """Run the model on one image and export the result.

    Args:
        img_path: Path of the image to process.
        save_dir: Directory used to build the output JSON path.
        predictor: Object whose ``predict(inputs)`` yields
            ``(bbox, lwords, pr_class_words, pr_relations)``.
        processor: Callable invoked as ``processor(img_path, ocr_path='')``
            to build the model inputs.

    Returns:
        Whatever ``export_kvu_for_manulife`` returns for the last window.
    """
    # Build "<save_dir>/<image stem>.json". The previous
    # `fname.replace(img_ext, '.json')` used the extension without its dot,
    # producing names such as "a..json".
    stem = os.path.splitext(os.path.basename(img_path))[0]
    json_path = os.path.join(save_dir, stem + '.json')

    inputs = processor(img_path, ocr_path='')
    width, height = imagesize.get(img_path)

    bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)

    if len(bbox) == 0:
        # Wrap the empty outputs so the loop below runs exactly once.
        bbox, lwords, pr_class_words, pr_relations = [bbox], [lwords], [pr_class_words], [pr_relations]

    # NOTE(review): every window is exported to the same `json_path` and only
    # the last window's result is returned - confirm that is intended.
    for i, window_boxes in enumerate(bbox):
        bbox[i] = [revert_box(bb, width, height) for bb in window_boxes]
        vat_outputs_invoice = export_kvu_for_manulife(
            json_path, lwords[i], bbox[i], pr_class_words[i], pr_relations[i],
            predictor.class_names,
        )

    return vat_outputs_invoice
|
|
|
|
|
|
def load_groundtruth(img_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor: KVUProcess, export_img: int) -> None:
    """Export the ground-truth labels of one image, optionally drawing them.

    Args:
        img_path: Path of the image; its stem selects the label file.
        json_dir: Directory holding ``<image stem>.json`` label files.
        save_dir: Directory that receives the exported JSON and, when
            ``export_img`` is 1, a ``kvu_results`` sub-directory with images.
        predictor: Supplies ``get_ground_truth_label`` and ``class_names``.
        processor: Supplies ``load_ground_truth``.
        export_img: When equal to 1, also save a visualisation of the image.
    """
    fname = os.path.basename(img_path)
    json_name = os.path.splitext(fname)[0] + '.json'

    inputs = processor.load_ground_truth(os.path.join(json_dir, json_name))
    bbox, lwords, pr_class_words, pr_relations = predictor.get_ground_truth_label(inputs)

    export_kvu_outputs(os.path.join(save_dir, json_name), lwords, pr_class_words, pr_relations, predictor.class_names)

    if export_img == 1:
        save_path = os.path.join(save_dir, 'kvu_results')
        create_dir(save_path)
        color_map = get_colormap()
        image = cv2.imread(img_path)
        # Use the predictor's class names rather than the `class_names` global,
        # which only exists when this file is executed as a script.
        image = visualize(image, bbox, pr_class_words, pr_relations, color_map, predictor.class_names, thickness=1)
        cv2.imwrite(os.path.join(save_path, fname), image)
|
|
|
|
def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor, export_img: int) -> None:
    """Run ``load_groundtruth`` on every image found directly in ``dir_path``.

    Files are matched by the extensions JPG, PNG, jpeg, jpg and png, in that
    order, and processed under a tqdm progress bar.
    """
    patterns = [
        os.path.join(dir_path, f'*.{ext}')
        for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png']
    ]
    image_paths = [match for pattern in patterns for match in glob.glob(pattern)]

    for image_path in tqdm(image_paths):
        load_groundtruth(image_path, json_dir, save_dir, predictor, processor, export_img)
|
|
|
|
def Predictor_KVU(image_url: str, save_dir: str, predictor: KVUPredictor, processor):
    """Download an image from a URL and run ``predict_image`` on it.

    The image is decoded with OpenCV and written as a JPEG under
    ``./Kie_Invoice_AP/tmp_image`` before being passed on.

    Args:
        image_url: URL of the image to download.
        save_dir: Forwarded to ``predict_image``.
        predictor: Forwarded to ``predict_image``.
        processor: Forwarded to ``predict_image``.

    Returns:
        The value returned by ``predict_image``.

    Raises:
        ValueError: If the downloaded bytes cannot be decoded as an image.
    """
    import hashlib
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    # Close the HTTP response as soon as the payload has been read.
    with urllib.request.urlopen(image_url) as response:
        arr = np.asarray(bytearray(response.read()), dtype=np.uint8)

    img = cv2.imdecode(arr, -1)
    if img is None:
        raise ValueError(f'Could not decode an image from {image_url}')

    tmp_dir = "./Kie_Invoice_AP/tmp_image"
    os.makedirs(tmp_dir, exist_ok=True)
    # The old path literal lacked its f-prefix, so every URL was written to a
    # file literally named "{image_url}.jpg". A raw URL is not a valid file
    # name, so a digest of it is used as a per-URL, filesystem-safe stem.
    url_digest = hashlib.sha256(image_url.encode('utf-8')).hexdigest()
    image_path = os.path.join(tmp_dir, f"{url_digest}.jpg")
    cv2.imwrite(image_path, img)

    vat_outputs = predict_image(image_path, save_dir, predictor, processor)
    return vat_outputs
|
|
|
|
|
|
if __name__ == "__main__":
    args = get_args()
    # Kept as a module-level name: other code in this file reads it as a global.
    class_names = ['others', 'title', 'key', 'value', 'header']
    predict_mode = {
        'normal': 0,
        'full_tokens': 1,
        'sliding': 2,
        'document': 3,
    }
    if args.mode not in predict_mode.values():
        raise SystemExit(
            f'Unsupported --mode {args.mode}; expected one of {sorted(predict_mode.values())}'
        )

    predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
    create_dir(args.save_dir)

    # Process the images found directly in --img_dir and write into --save_dir.
    # Previously one hard-coded image was processed and the results went to a
    # hard-coded directory that was not the one created above.
    # NOTE(review): --dir_level is parsed but not used here.
    image_paths = set()
    for ext in ('JPG', 'PNG', 'jpeg', 'jpg', 'png'):
        image_paths.update(glob.glob(os.path.join(args.img_dir, f'*.{ext}')))

    # The set + sort removes duplicates (case-insensitive filesystems match
    # both 'jpg' and 'JPG') and gives a stable processing order.
    for image_path in tqdm(sorted(image_paths)):
        predict_image(image_path, args.save_dir, predictor, processor)