sbt-idp/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/kvu.py

import imagesize
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf
from sdsvkvu.modules.predictor import KVUPredictor
from sdsvkvu.modules.preprocess import KVUProcessor, DocKVUProcessor, SBTProcessor
from sdsvkvu.modules.run_ocr import load_ocr_engine, process_img
from sdsvkvu.utils.utils import post_process_basic_ocr
from sdsvkvu.sources.utils import revert_scale_bbox, Timer

DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"


class KVUEngine:
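    """End-to-end KVU engine: runs OCR on an input image, builds model inputs for the
    configured inference mode, and returns predicted boxes, words, classes and relations.
    """
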
    def __init__(self, setting_file: str = DEFAULT_SETTING_PATH, ocr_engine=None, **kwargs) -> None:
        configs = OmegaConf.load(setting_file)
        for key, param in kwargs.items():  # overwrite default settings with keyword arguments
            if key not in configs:
                raise ValueError(f"Invalid setting found in KVUEngine: {key}")
            if isinstance(param, dict):
                for k, v in param.items():
                    if k not in configs[key]:
                        raise ValueError(f"Invalid setting found in KVUEngine: {key}.{k}")
                    configs[key][k] = v
            else:
                configs[key] = param

        self.predictor = KVUPredictor(configs)
        self._settings, tokenizer_layoutxlm, feature_extractor = self.predictor.get_process_configs()

        # Select the preprocessor that matches the configured inference mode
        mode = self._settings.mode
        if mode in (0, 1, 2):
            self.processor = KVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
                                          feature_extractor=feature_extractor,
                                          **self._settings)
        elif mode == 3:
            self.processor = DocKVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
                                             feature_extractor=feature_extractor,
                                             **self._settings)
        elif mode == 4:
            self.processor = SBTProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
                                          feature_extractor=feature_extractor,
                                          **self._settings)
        else:
            raise ValueError(f"[ERROR] Inference mode {mode} is not supported")

        if ocr_engine is None:
            print("[INFO] Loading internal OCR engine")
            configs.ocr_engine.device = configs.device
            self.ocr_engine = load_ocr_engine(configs.ocr_engine)
        else:
            print("[INFO] Using external OCR engine")
            self.ocr_engine = ocr_engine

    def predict(self, img_path):
        """Run OCR and KVU prediction on a single image and return the raw outputs."""
        lbboxes, lwords, image = process_img(img_path, self.ocr_engine)
        lwords = post_process_basic_ocr(lwords)
        if len(lbboxes) == 0:
            print("[WARNING] Empty document")
            return image, [[]], [[]], [[]], [[]]

        height, width, _ = image.shape
        image = Image.fromarray(image)
        inputs = self.processor(lbboxes, lwords, image, width=width, height=height)

        with Timer("kvu"):
            lbbox, lwords, pr_class_words, pr_relations = self.predictor.predict(inputs)

        # Rescale predicted bounding boxes back to the original image size
        for i in range(len(lbbox)):
            lbbox[i] = [revert_scale_bbox(bb, width=width, height=height) for bb in lbbox[i]]

        return image, lbbox, lwords, pr_class_words, pr_relations
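

# Minimal usage sketch (illustrative, not part of the module): the image path is a
# placeholder, and the "device" override and its value are assumptions based on the
# top-level "device" key referenced in settings.yml above.
if __name__ == "__main__":
    engine = KVUEngine(device="cuda:0")
    image, lbbox, lwords, pr_class_words, pr_relations = engine.predict("sample.jpg")
    print(pr_class_words)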