85 lines
2.8 KiB
Python
Executable File
85 lines
2.8 KiB
Python
Executable File
import numpy as np
|
|
from .utils import get_crop_img_and_bbox
|
|
from sdsvtr import StandaloneSATRNRunner
|
|
from sdsvtd import StandaloneYOLOXRunner
|
|
import urllib
|
|
import cv2
|
|
import logging
|
|
import logging.config
|
|
from utils.logging.logging import LOGGER_CONFIG
|
|
# Module import side effect: apply the project-wide logging configuration
# (LOGGER_CONFIG, a dictConfig-style mapping), then obtain the logger used by
# every class in this module.
logging.config.dictConfig(config=LOGGER_CONFIG)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class YoloX:
    """Thin wrapper around ``StandaloneYOLOXRunner`` (YOLOX detector).

    Args:
        checkpoint: Detector checkpoint, forwarded unchanged to the runner.
        device: Device string passed to the runner. Defaults to ``"cuda:0"``,
            the value previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, checkpoint, device="cuda:0"):
        self.model = StandaloneYOLOXRunner(checkpoint, device=device)

    def inference(self, img=None):
        """Run the detector on ``img`` and return the runner's raw output.

        The output format is whatever ``StandaloneYOLOXRunner`` returns; callers
        in this module index ``[0]`` into it or ``np.vstack`` it.
        """
        return self.model(img)
|
|
|
|
|
|
class Classifier_SATRN:
    """Thin wrapper around ``StandaloneSATRNRunner`` (text recognizer).

    The runner is built with ``return_confident=True`` so each call yields
    predictions together with confidences.

    Args:
        checkpoint: Recognizer checkpoint, forwarded unchanged to the runner.
        device: Device string passed to the runner. Defaults to ``"cuda:0"``,
            the value previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, checkpoint, device="cuda:0"):
        self.model = StandaloneSATRNRunner(
            checkpoint, return_confident=True, device=device
        )

    def inference(self, numpy_image):
        """Recognize text in ``numpy_image``.

        Args:
            numpy_image: Input forwarded to the runner; callers in this module
                pass a list of cropped images.

        Returns:
            ``(preds_str, confidence)``: the first two items of the runner's
            output. Indexing (rather than tuple unpacking) is used so that any
            extra items the runner returns are ignored.
        """
        result = self.model(numpy_image)
        return result[0], result[1]
|
|
|
|
class OcrEngineForYoloX_Invoice:
    """Detection + recognition pipeline used for invoices.

    Args:
        det_ckpt: Checkpoint for the ``YoloX`` detector.
        cls_ckpt: Checkpoint for the ``Classifier_SATRN`` recognizer.
    """

    def __init__(self, det_ckpt, cls_ckpt):
        self.det = YoloX(det_ckpt)
        self.cls = Classifier_SATRN(cls_ckpt)

    def run_image(self, img):
        """Detect boxes in ``img``, crop each one and recognize its text.

        Returns:
            ``(lbboxes, lwords)``: the boxes returned by
            ``get_crop_img_and_bbox`` for every successfully cropped detection,
            and the recognizer's predictions for those crops.
            ``([], [])`` if nothing was detected or no box could be cropped.
        """
        # Only the first element of the detector output is used here.
        pred_det = self.det.inference(img)[0]

        # Order boxes by box[1] then box[0]. Presumably (y1, x1), i.e. a
        # top-to-bottom, left-to-right reading order (TODO confirm box layout).
        pred_det = sorted(pred_det, key=lambda box: (box[1], box[0]))
        if not pred_det:
            return [], []

        bboxes = np.vstack(pred_det)

        lbboxes = []
        lcropped_img = []
        for bbox in bboxes:
            try:
                crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True)
            except AssertionError as e:
                # get_crop_img_and_bbox signals an unusable box via assert;
                # skip it but keep processing the rest of the image.
                logger.warning("[ERROR]: Skipping invalid bbox in image: %s", e)
                continue
            lbboxes.append(bbox_)
            lcropped_img.append(crop_img)

        # Every box was rejected: do not call the recognizer on an empty batch.
        if not lcropped_img:
            return [], []

        lwords, _ = self.cls.inference(lcropped_img)
        return lbboxes, lwords
|
|
|
|
class OcrEngineForYoloX_ID_Driving:
    """Detection + recognition pipeline used for ID cards / driving licences.

    Args:
        det_ckpt: Checkpoint for the ``YoloX`` detector.
        cls_ckpt: Checkpoint for the ``Classifier_SATRN`` recognizer.
    """

    def __init__(self, det_ckpt, cls_ckpt):
        self.det = YoloX(det_ckpt)
        self.cls = Classifier_SATRN(cls_ckpt)

    def run_image(self, img):
        """Detect boxes in ``img``, crop each one and recognize its text.

        Unlike the invoice engine, the whole detector output is stacked and
        the boxes are not re-sorted.

        Returns:
            ``(lbboxes, lwords)``: the boxes returned by
            ``get_crop_img_and_bbox`` for every successfully cropped detection,
            and the recognizer's predictions for those crops.
            ``([], [])`` if no detected box could be cropped.

        Raises:
            AssertionError: if the detector found no boxes. This is raised
                explicitly (not via ``assert``) so it still fires under
                ``python -O``. The type is kept for callers that catch it.
        """
        no_bbox_msg = "No bbox found in image, skipped"

        pred_det = self.det.inference(img)
        # np.vstack raises ValueError on an empty sequence; report it as the
        # same "no bbox" condition instead.
        if len(pred_det) == 0:
            raise AssertionError(no_bbox_msg)

        bboxes = np.vstack(pred_det)
        if len(bboxes) == 0:
            raise AssertionError(no_bbox_msg)

        lbboxes = []
        lcropped_img = []
        for bbox in bboxes:
            try:
                crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True)
            except AssertionError as e:
                # get_crop_img_and_bbox signals an unusable box via assert;
                # skip it but keep processing the rest of the image.
                logger.warning("[ERROR]: Skipping invalid bbox in image: %s", e)
                continue
            lbboxes.append(bbox_)
            lcropped_img.append(crop_img)

        # Every box was rejected: do not call the recognizer on an empty batch.
        if not lcropped_img:
            return [], []

        lwords, _ = self.cls.inference(lcropped_img)
        return lbboxes, lwords
|