sbt-idp/cope2n-ai-fi/common/utils/ocr_yolox.py

85 lines
2.8 KiB
Python
Executable File

import numpy as np
from .utils import get_crop_img_and_bbox
from sdsvtr import StandaloneSATRNRunner
from sdsvtd import StandaloneYOLOXRunner
import urllib
import cv2
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Apply the shared project logging configuration at import time, then
# obtain this module's own logger from it.
logging.config.dictConfig(config=LOGGER_CONFIG)
logger = logging.getLogger(name=__name__)
class YoloX:
    """Thin wrapper around ``StandaloneYOLOXRunner`` (text-box detector).

    Args:
        checkpoint: checkpoint passed straight to ``StandaloneYOLOXRunner``.
        device: device string handed to the runner. Defaults to ``"cuda:0"``,
            the value that was previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, checkpoint, device="cuda:0"):
        self.model = StandaloneYOLOXRunner(checkpoint, device=device)

    def inference(self, img=None):
        """Run the detector on ``img`` and return the runner's raw output."""
        return self.model(img)
class Classifier_SATRN:
    """Thin wrapper around ``StandaloneSATRNRunner`` (text recogniser).

    Args:
        checkpoint: checkpoint passed straight to ``StandaloneSATRNRunner``.
        device: device string handed to the runner. Defaults to ``"cuda:0"``,
            the value that was previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, checkpoint, device="cuda:0"):
        # return_confident=True asks the runner for confidences as well;
        # inference() reads them from the second element of the result.
        self.model = StandaloneSATRNRunner(
            checkpoint, return_confident=True, device=device
        )

    def inference(self, numpy_image):
        """Recognise text in ``numpy_image``.

        The OCR engines in this module pass a list of cropped images here.

        Returns:
            ``(preds_str, confidence)``: the first two elements of the runner
            output (presumably the predicted strings and their confidences;
            confirm against ``sdsvtr``).
        """
        result = self.model(numpy_image)
        return result[0], result[1]
class OcrEngineForYoloX_Invoice:
    """Detect text boxes with YOLOX, then recognise each crop with SATRN.

    Invoice variant: detections are taken from the first element of the
    detector output and sorted into a reading order before cropping.
    """

    def __init__(self, det_ckpt, cls_ckpt):
        """Build the detector and recogniser.

        Args:
            det_ckpt: checkpoint passed to ``YoloX``.
            cls_ckpt: checkpoint passed to ``Classifier_SATRN``.
        """
        self.det = YoloX(det_ckpt)
        self.cls = Classifier_SATRN(cls_ckpt)

    def run_image(self, img):
        """Run detection and recognition on a single image.

        Args:
            img: image forwarded to the detector and cropped with
                ``get_crop_img_and_bbox``.

        Returns:
            ``(lbboxes, lwords)``: the boxes returned by
            ``get_crop_img_and_bbox`` for every box that could be cropped, and
            the recognised strings for those crops. Both lists are empty when
            nothing is detected or every detected box is rejected.
        """
        pred_det = self.det.inference(img)[0]
        # Order by the second coordinate, then the first (presumably y1 then
        # x1, i.e. top-to-bottom, left-to-right; confirm the box layout).
        pred_det = sorted(pred_det, key=lambda box: [box[1], box[0]])
        if not pred_det:
            return [], []

        bboxes = np.vstack(pred_det)
        lbboxes = []
        lcropped_img = []
        for bbox in bboxes:
            try:
                crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True)
            except AssertionError as e:
                # The helper signals an unusable box via AssertionError.
                logger.info(e)
                logger.info("[ERROR]: Skipping invalid bbox in image")
                continue
            lbboxes.append(bbox_)
            lcropped_img.append(crop_img)

        # Every box was rejected: don't hand the recogniser an empty batch.
        if not lcropped_img:
            return [], []

        lwords, _ = self.cls.inference(lcropped_img)
        return lbboxes, lwords
class OcrEngineForYoloX_ID_Driving:
    """Detect text boxes with YOLOX, then recognise each crop with SATRN.

    ID / driving-licence variant: every element of the detector output is
    stacked into one box array (no reading-order sort is applied).
    """

    def __init__(self, det_ckpt, cls_ckpt):
        """Build the detector and recogniser.

        Args:
            det_ckpt: checkpoint passed to ``YoloX``.
            cls_ckpt: checkpoint passed to ``Classifier_SATRN``.
        """
        self.det = YoloX(det_ckpt)
        self.cls = Classifier_SATRN(cls_ckpt)

    def run_image(self, img):
        """Run detection and recognition on a single image.

        Args:
            img: image forwarded to the detector and cropped with
                ``get_crop_img_and_bbox``.

        Returns:
            ``(lbboxes, lwords)``: the boxes returned by
            ``get_crop_img_and_bbox`` for every box that could be cropped, and
            the recognised strings for those crops. Both lists are empty when
            every detected box is rejected.

        Raises:
            AssertionError: if the stacked detections contain no boxes.
        """
        pred_det = self.det.inference(img)
        bboxes = np.vstack(pred_det)
        if len(bboxes) == 0:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type and message are unchanged.
            raise AssertionError("No bbox found in image, skipped")

        lbboxes = []
        lcropped_img = []
        for bbox in bboxes:
            try:
                crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True)
            except AssertionError as e:
                # The helper signals an unusable box via AssertionError.
                logger.info(e)
                logger.info("[ERROR]: Skipping invalid bbox in image")
                continue
            lbboxes.append(bbox_)
            lcropped_img.append(crop_img)

        # Every box was rejected: don't hand the recogniser an empty batch.
        if not lcropped_img:
            return [], []

        lwords, _ = self.cls.inference(lcropped_img)
        return lbboxes, lwords