"""Detect an ID card in an image with an MMDetection model and dewarp it."""
import functools
import urllib
import urllib.request

from mmdet.apis import inference_detector, init_detector
import cv2
import numpy as np

# Detector assets used by crop_location().
_CONFIG_PATH = "./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py"
_CHECKPOINT_PATH = "./models/Kie_AHung/best_bbox_mAP_epoch_100.pth"
_DEVICE = "cuda:0"

# Detection labels consumed by ImageTransformer. The corner labels are listed
# in the order dewarp() expects: top-left, top-right, bottom-right, bottom-left.
_CARD_LABEL = "card"
_CORNER_LABELS = ("top_left", "top_right", "bottom_right", "bottom_left")


def get_center(box):
    """Return the integer ``[x, y]`` centre of a box.

    Args:
        box: ``(xmin, ymin, xmax, ymax)``.

    Returns:
        list: ``[x_center, y_center]``, each truncated with ``int()``.
    """
    xmin, ymin, xmax, ymax = box
    x_center = int((xmin + xmax) / 2)
    y_center = int((ymin + ymax) / 2)
    return [x_center, y_center]


def cal_euclidean_dist(p1, p2):
    """Return the Euclidean distance between two points.

    Both points are converted to float arrays first, so plain lists and
    tuples are accepted and unsigned integer inputs cannot wrap around on
    subtraction.
    """
    p1 = np.asarray(p1, dtype=np.float64)
    p2 = np.asarray(p2, dtype=np.float64)
    return np.linalg.norm(p1 - p2)


def bbox_to_four_poinst(bbox):
    """Convert one bounding box to its four corner points.

    Args:
        bbox: ``(xmin, ymin, xmax, ymax)``.

    Returns:
        list[list]: corners ordered top-left, top-right, bottom-right,
        bottom-left.
    """
    xmin, ymin, xmax, ymax = bbox
    return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]


def find_closest_point(src_point, point_list):
    """Return the index of the point in ``point_list`` closest to ``src_point``.

    Args:
        src_point (list): point in ``[x, y]`` format.
        point_list (list[list]): candidate points in ``[x, y]`` format.

    Returns:
        int: index into ``point_list`` of the nearest point (the first one
        on ties).

    Raises:
        ValueError: if ``point_list`` is empty.
    """
    points = np.asarray(point_list, dtype=np.float64)
    if points.size == 0:
        raise ValueError("point_list must contain at least one point")
    # Build a real list: np.array() over a bare generator yields a 0-d object
    # array, on which argmin never compares the distances.
    dist_list = np.array(
        [cal_euclidean_dist(src_point, target_point) for target_point in points]
    )
    return int(np.argmin(dist_list))


def crop_align_card(img_src, corner_box_list):
    """Dewarp the card region of an image.

    Args:
        img_src: source image; it is not modified.
        corner_box_list: either four corner boxes (a sequence of
            ``(xmin, ymin, xmax, ymax)`` sequences ordered top-left,
            top-right, bottom-right, bottom-left), whose centres are used as
            the card corners, or a single flat ``(xmin, ymin, xmax, ymax)``
            box whose own corners are used.

    Returns:
        The perspective-corrected crop produced by ``dewarp``.
    """
    # A nested first element means "four corner boxes"; a scalar first element
    # means "one flat box". Tuples and arrays are accepted as well as lists.
    if isinstance(corner_box_list[0], (list, tuple, np.ndarray)):
        poinst = [get_center(box) for box in corner_box_list]
    else:
        poinst = bbox_to_four_poinst(corner_box_list)
    # cv2.warpPerspective writes to a new array, so no defensive copy of the
    # source image is needed.
    return dewarp(img_src, poinst)


def _side_length(a, b):
    """Return the length of the segment between 2-D points ``a`` and ``b``."""
    return np.sqrt(((a[0] - b[0]) ** 2) + ((a[1] - b[1]) ** 2))


def dewarp(image, poinst):
    """Warp the quadrilateral ``poinst`` of ``image`` onto an upright rectangle.

    Args:
        image: source image passed to ``cv2.warpPerspective``.
        poinst: four ``[x, y]`` points ordered top-left, top-right,
            bottom-right, bottom-left.

    Returns:
        The warped image. Its width and height are the longer of the two
        opposite sides of the quadrilateral, truncated to whole pixels.
    """
    # Always convert: cv2.getPerspectiveTransform needs float32 points, and
    # the input may be a list, a tuple or an array of another dtype.
    poinst = np.asarray(poinst, dtype="float32")
    tl, tr, br, bl = poinst

    max_width = max(int(_side_length(br, bl)), int(_side_length(tr, tl)))
    max_height = max(int(_side_length(tr, br)), int(_side_length(tl, bl)))

    dst = np.array(
        [
            [0, 0],
            [max_width - 1, 0],
            [max_width - 1, max_height - 1],
            [0, max_height - 1],
        ],
        dtype="float32",
    )
    matrix = cv2.getPerspectiveTransform(poinst, dst)
    return cv2.warpPerspective(image, matrix, (max_width, max_height))


class MdetPredictor:
    """Thin wrapper around an MMDetection detector."""

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        """Load the detector described by ``config`` / ``checkpoint`` on ``device``."""
        self.model = init_detector(config, checkpoint, device=device)
        self.class_names = self.model.CLASSES

    def infer(self, image, threshold=0.2):
        """Run detection and keep boxes scoring at least ``threshold``.

        Args:
            image: input accepted by ``inference_detector``.
            threshold (float): minimum score for a box to be kept.

        Returns:
            tuple: ``(bboxes, labels)`` where ``bboxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` float lists and ``labels`` holds the
            matching class names.
        """
        bbox_result = inference_detector(self.model, image)
        # NOTE(review): assumes the MMDetection 2.x result format, i.e. one
        # (n, 5) array per class with the score in the last column; models
        # with a mask head return a (bbox_result, segm_result) tuple instead.
        if isinstance(bbox_result, tuple):
            bbox_result = bbox_result[0]

        bboxes = np.vstack(bbox_result)
        labels = np.concatenate(
            [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
        )

        keep = bboxes[:, -1] >= threshold
        res_bboxes = bboxes[keep, :4].tolist()
        res_labels = [self.class_names[label] for label in labels[keep]]
        return res_bboxes, res_labels


class ImageTransformer:
    """Find a card in an image and return a dewarped crop of it."""

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        """Create the underlying corner/card detector."""
        self.corner_detect_model = MdetPredictor(config, checkpoint, device)

    def __call__(self, image, threshold=0.2):
        """Detect the card in ``image`` and crop it.

        Args:
            image (np.ndarray): BGR image
            threshold (float): minimum detection score forwarded to the
                detector.

        Returns:
            The dewarped card image, or ``None`` when no ``"card"`` box was
            detected.
        """
        corner_result = self.corner_detect_model.infer(image, threshold=threshold)
        corners_dict = self.__extract_corners(corner_result)
        return self.__crop_image_based_on_corners(image, corners_dict)

    def __extract_corners(self, corner_result):
        """Map each detected label to the first box reported for it.

        Args:
            corner_result: ``(bboxes, labels)`` as returned by
                ``MdetPredictor.infer``.

        Returns:
            dict: label -> ``[xmin, ymin, xmax, ymax]`` with integer
            coordinates.
        """
        bboxes, labels = corner_result
        output = {}
        for label, box in zip(labels, bboxes):
            # Keep only the first box seen for a label.
            if label not in output:
                output[label] = [int(x) for x in box]
        return output

    def __crop_image_based_on_corners(self, image, corners_dict):
        """Crop the card using the four corners when available, else the card box.

        Args:
            image: source image.
            corners_dict (dict): label -> integer box, as built by
                ``__extract_corners``.

        Returns:
            The dewarped card image, or ``None`` if ``"card"`` was not detected.
        """
        if _CARD_LABEL not in corners_dict:
            return None

        # Check the corner labels by name rather than by counting keys.
        if all(label in corners_dict for label in _CORNER_LABELS):
            points = [corners_dict[label] for label in _CORNER_LABELS]
        else:
            points = corners_dict[_CARD_LABEL]
        return crop_align_card(image, points)


@functools.lru_cache(maxsize=1)
def _get_transform_module():
    """Build the card transformer once and reuse it across calls."""
    return ImageTransformer(
        config=_CONFIG_PATH,
        checkpoint=_CHECKPOINT_PATH,
        device=_DEVICE,
    )


def crop_location(image_url):
    """Download an image and return the dewarped card found in it.

    Args:
        image_url: URL passed to ``urllib.request.urlopen``.

    Returns:
        The dewarped card image, or the decoded original image when no card
        is detected.

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    # NOTE(review): urlopen() also accepts non-HTTP schemes such as file://;
    # validate image_url upstream if it can come from an untrusted source.
    with urllib.request.urlopen(image_url) as req:
        arr = np.asarray(bytearray(req.read()), dtype=np.uint8)

    # Flag -1 decodes the image unchanged, so the channel count follows the file.
    # NOTE(review): ImageTransformer documents a BGR input - confirm that
    # grayscale / alpha images are acceptable to the detector.
    img = cv2.imdecode(arr, -1) if arr.size else None
    if img is None:
        raise ValueError(f"could not decode an image from {image_url!r}")

    card_image = _get_transform_module()(img)
    return card_image if card_image is not None else img