73 lines
3.3 KiB
Python
73 lines
3.3 KiB
Python
|
import imagesize
|
||
|
from PIL import Image
|
||
|
from pathlib import Path
|
||
|
from omegaconf import OmegaConf
|
||
|
|
||
|
from sdsvkvu.modules.predictor import KVUPredictor
|
||
|
from sdsvkvu.modules.preprocess import KVUProcessor, DocKVUProcessor, SBTProcessor
|
||
|
from sdsvkvu.modules.run_ocr import load_ocr_engine, process_img
|
||
|
from sdsvkvu.utils.utils import post_process_basic_ocr
|
||
|
from sdsvkvu.sources.utils import revert_scale_bbox, Timer
|
||
|
|
||
|
# Default settings file: ``settings.yml`` in the parent of this module's directory.
# Joined with pathlib rather than string concatenation so the separator is platform-correct.
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1] / "settings.yml")
|
||
|
|
||
|
|
||
|
class KVUEngine:
    """End-to-end KVU pipeline: OCR, then preprocessing, then the KVU predictor.

    Wires together a ``KVUPredictor``, the processor class selected by the
    ``mode`` setting, and an OCR engine (external or loaded internally), all
    configured from an OmegaConf settings file.
    """

    # Processor class for each supported inference mode (``settings.mode``).
    _PROCESSORS = {
        0: KVUProcessor,
        1: KVUProcessor,
        2: KVUProcessor,
        3: DocKVUProcessor,
        4: SBTProcessor,
    }

    def __init__(self, setting_file: str = DEFAULT_SETTING_PATH, ocr_engine=None, **kwargs) -> None:
        """Build the engine from a settings file.

        Args:
            setting_file: Path of the settings file, loaded with ``OmegaConf.load``.
            ocr_engine: Pre-built OCR engine to use as-is. When ``None``, an
                internal engine is loaded from the ``ocr_engine`` settings section.
            **kwargs: Overrides for top-level settings. A ``dict`` value updates
                only the given sub-keys of that section; any other value
                replaces the setting outright.

        Raises:
            ValueError: If an override names an unknown setting (or sub-setting),
                a ``dict`` override targets a setting that is not a section, or
                the configured inference mode is not supported.
        """
        configs = OmegaConf.load(setting_file)
        self._apply_overrides(configs, kwargs)

        self.predictor = KVUPredictor(configs)
        self._settings, tokenizer_layoutxlm, feature_extractor = self.predictor.get_process_configs()
        self.processor = self._build_processor(self._settings, tokenizer_layoutxlm, feature_extractor)
        self.ocr_engine = self._resolve_ocr_engine(configs, ocr_engine)

    @staticmethod
    def _apply_overrides(configs, overrides: dict) -> None:
        """Merge keyword overrides into ``configs`` in place, rejecting unknown keys."""
        for key, param in overrides.items():
            if key not in configs:
                raise ValueError(f"Invalid setting found in KVUEngine: {key}")
            if not isinstance(param, dict):
                configs[key] = param
                continue
            section = configs[key]
            # A dict override only makes sense for a nested section; on a scalar
            # the membership test below would raise TypeError or do a substring check.
            if not OmegaConf.is_dict(section):
                raise ValueError(
                    f"Invalid setting found in KVUEngine: {key} is not a section "
                    f"and cannot be overridden with a dict"
                )
            # Update only the given sub-keys so the rest of the section keeps its values.
            for sub_key, value in param.items():
                if sub_key not in section:
                    raise ValueError(f"Invalid setting found in KVUEngine: {key}.{sub_key}")
                section[sub_key] = value

    @classmethod
    def _build_processor(cls, settings, tokenizer_layoutxlm, feature_extractor):
        """Instantiate the processor class registered for ``settings.mode``."""
        mode = settings.mode
        processor_cls = cls._PROCESSORS.get(mode)
        if processor_cls is None:
            raise ValueError(f'[ERROR] Inferencing mode of {mode} is not supported')
        return processor_cls(tokenizer_layoutxlm=tokenizer_layoutxlm,
                             feature_extractor=feature_extractor,
                             **settings)

    @staticmethod
    def _resolve_ocr_engine(configs, ocr_engine):
        """Return ``ocr_engine`` if given, otherwise load the internal one from ``configs``."""
        if ocr_engine is not None:
            print("[INFO] Load external OCR Engine")
            return ocr_engine
        print("[INFO] Load internal OCR Engine")
        # Reuse the top-level ``device`` setting for the internal OCR engine.
        configs.ocr_engine.device = configs.device
        return load_ocr_engine(configs.ocr_engine)

    def predict(self, img_path):
        """Run OCR and key-value prediction on a single image.

        Args:
            img_path: Image to process; passed unchanged to ``process_img``.

        Returns:
            A 5-tuple ``(image, lbbox, lwords, pr_class_words, pr_relations)``.
            On the normal path ``image`` is a ``PIL.Image`` and every box in
            ``lbbox`` has been passed through ``revert_scale_bbox`` with the
            image's width and height. If OCR finds no boxes, the image is
            returned exactly as ``process_img`` produced it, followed by four
            ``[[]]`` placeholders.
        """
        lbboxes, lwords, image = process_img(img_path, self.ocr_engine)
        lwords = post_process_basic_ocr(lwords)

        # len() rather than truthiness: works even if lbboxes is an array type.
        if len(lbboxes) == 0:
            print("[WARNING] Empty document")
            # NOTE(review): this path returns the raw image from process_img,
            # whereas the normal path returns a PIL.Image — confirm callers handle both.
            return image, [[]], [[]], [[]], [[]]

        height, width, _ = image.shape
        image = Image.fromarray(image)

        inputs = self.processor(lbboxes, lwords, image, width=width, height=height)

        with Timer("kvu"):
            lbbox, lwords, pr_class_words, pr_relations = self.predictor.predict(inputs)

        # Presumably maps model-space boxes back to pixel coordinates of this image.
        lbbox = [
            [revert_scale_bbox(bb, width=width, height=height) for bb in boxes]
            for boxes in lbbox
        ]
        return image, lbbox, lwords, pr_class_words, pr_relations
|