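"""KVU inference engine.

Runs OCR on a document image, preprocesses the OCR output for the configured
inference mode, and predicts word classes and relations with KVUPredictor.
"""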
import imagesize
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf
from sdsvkvu.modules.predictor import KVUPredictor
from sdsvkvu.modules.preprocess import KVUProcessor, DocKVUProcessor, SBTProcessor
from sdsvkvu.modules.run_ocr import load_ocr_engine, process_img
from sdsvkvu.utils.utils import post_process_basic_ocr
from sdsvkvu.sources.utils import revert_scale_bbox, Timer
# Default settings file: ``settings.yml`` one directory above the directory
# that contains this module (``parents[1]`` of this file's path). Joined with
# pathlib instead of string concatenation so the separator is always correct;
# kept as ``str`` so callers that treat it as a plain string keep working.
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1] / "settings.yml")
# Engine that wires together a KVUPredictor, a mode-specific input processor
# and an OCR engine.
# NOTE(review): this copy of the file has lost its indentation and several
# lines (at least some `else:` lines and the continuation arguments of the
# processor constructor calls below). Restore them from upstream before use.
class KVUEngine:
# Build the engine from a YAML settings file.
#
# Args:
#   setting_file: path passed to OmegaConf.load; defaults to DEFAULT_SETTING_PATH.
#   ocr_engine: OCR engine to use. When None, one is created with
#     load_ocr_engine(configs.ocr_engine).
#   **kwargs: overrides for top-level keys of the settings file. For a dict
#     value, each of its keys must already exist in that section and is
#     assigned individually.
#
# Raises:
#   ValueError: if an override names a key or sub-key that is not in the
#     settings file, or (intended, see notes below) if the mode is unsupported.
def __init__(self, setting_file: str = DEFAULT_SETTING_PATH, ocr_engine=None, **kwargs) -> None:
configs = OmegaConf.load(setting_file)
for key, param in kwargs.items(): # overwrite default settings by keyword arguments
# Only keys already present in the settings file may be overridden.
# NOTE(review): ValueError receives the message and key as separate args,
# so str(err) renders as a tuple rather than one formatted message.
if key not in configs:
raise ValueError("Invalid setting found in KVUEngine: ", key)
if isinstance(param, dict):
# Merge a nested override key by key, validating each sub-key.
for k, v in param.items():
if k not in configs[key]:
raise ValueError("Invalid setting found in KVUEngine: ", key, k)
configs[key][k] = v
# NOTE(review): as shown, the next line also runs for dict params and would
# replace the section merged key-by-key above with the raw dict. An `else:`
# line is probably missing from this copy; confirm against upstream.
configs[key] = param
self.predictor = KVUPredictor(configs)
# The predictor returns the runtime settings plus the tokenizer and feature
# extractor used to construct the input processor.
# NOTE(review): feature_extractor is unused in the visible lines; presumably
# it is passed among the missing constructor arguments below.
self._settings, tokenizer_layoutxlm, feature_extractor = self.predictor.get_process_configs()
# Choose the input processor class from the inference mode.
mode = self._settings.mode
if mode in (0, 1, 2):
self.processor = KVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
# NOTE(review): the remaining arguments and closing ")" are missing here.
elif mode == 3:
self.processor = DocKVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
# NOTE(review): the remaining arguments and closing ")" are missing here.
elif mode == 4:
self.processor = SBTProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm,
# NOTE(review): the remaining arguments and closing ")" are missing here.
# NOTE(review): this raise is presumably the `else:` branch of the mode
# dispatch (unsupported mode); the `else:` line is missing from this copy.
raise ValueError(f'[ERROR] Inferencing mode of {mode} is not supported')
if ocr_engine is None:
print("[INFO] Load internal OCR Engine")
# Forward the global device setting to the OCR engine config before loading it.
configs.ocr_engine.device = configs.device
self.ocr_engine = load_ocr_engine(configs.ocr_engine)
# NOTE(review): the next two lines are presumably the `else:` branch (use the
# caller-supplied engine); without that `else:` they would overwrite the
# internally loaded engine with ocr_engine (possibly None).
print("[INFO] Load external OCR Engine")
self.ocr_engine = ocr_engine
def predict(self, img_path):
    """Extract key-value predictions from one document image.

    ``img_path`` is handed unchanged, together with ``self.ocr_engine``, to
    ``process_img``, which yields the OCR boxes, the OCR words and the image.

    Returns:
        A 5-tuple ``(image, lbbox, lwords, pr_class_words, pr_relations)``.

        * If OCR produced no boxes, a warning is printed, the model is not
          run, and the result is ``(image, [[]], [[]], [[]], [[]])``, where
          ``image`` is the object returned by ``process_img`` as-is.
        * Otherwise ``image`` is the ``PIL.Image`` built with
          ``Image.fromarray`` from that object, and ``lbbox`` is the
          predictor's box output with every box passed through
          ``revert_scale_bbox`` using the image's width and height.
    """
    ocr_boxes, ocr_words, raw_image = process_img(img_path, self.ocr_engine)
    ocr_words = post_process_basic_ocr(ocr_words)

    # Short-circuit: with no OCR boxes there is nothing to feed the model.
    # NOTE(review): this path returns the object from process_img unchanged,
    # while the normal path returns a PIL.Image -- confirm callers expect that.
    if len(ocr_boxes) == 0:
        print("[WARNING] Empty document")
        return raw_image, [[]], [[]], [[]], [[]]

    # The unpacking requires a 3-dimensional shape; the third dim is ignored.
    height, width, _ = raw_image.shape
    pil_image = Image.fromarray(raw_image)

    model_inputs = self.processor(ocr_boxes, ocr_words, pil_image, width=width, height=height)

    with Timer("kvu"):
        lbbox, lwords, pr_class_words, pr_relations = self.predictor.predict(model_inputs)

    # Undo the box scaling via revert_scale_bbox, replacing each group of
    # boxes in place.
    for idx, group in enumerate(lbbox):
        lbbox[idx] = [revert_scale_bbox(bb, width=width, height=height) for bb in group]

    return pil_image, lbbox, lwords, pr_class_words, pr_relations