sbt-idp/cope2n-ai-fi/common/crop_location.py

from mmdet.apis import inference_detector, init_detector
import cv2
import numpy as np
import urllib.request


def get_center(box):
    xmin, ymin, xmax, ymax = box
    x_center = int((xmin + xmax) / 2)
    y_center = int((ymin + ymax) / 2)
    return [x_center, y_center]


def cal_euclidean_dist(p1, p2):
    return np.linalg.norm(np.asarray(p1) - np.asarray(p2))


def bbox_to_four_poinst(bbox):
    """Convert one bounding box to its 4 corner points.

    Args:
        bbox (list): box in [xmin, ymin, xmax, ymax] format
    """
    xmin, ymin, xmax, ymax = bbox
    poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
    return poinst


def find_closest_point(src_point, point_list):
    """Return the index of the point in ``point_list`` closest to ``src_point``.

    Args:
        src_point (list): point in xy format
        point_list (list[list]): list of points in xy format
    """
    point_list = np.array(point_list)
    dist_list = np.array(
        [cal_euclidean_dist(src_point, target_point) for target_point in point_list]
    )
    index_closest_point = np.argmin(dist_list)
    return index_closest_point


def crop_align_card(img_src, corner_box_list):
    """Dewarp an image based on its four corners.

    Args:
        img_src (np.ndarray): source image
        corner_box_list (list): either four corner boxes (their centers form the
            quad) or a single [xmin, ymin, xmax, ymax] card box
    """
    img = img_src.copy()
    if isinstance(corner_box_list[0], list):
        poinst = [get_center(box) for box in corner_box_list]
    else:
        xmin, ymin, xmax, ymax = corner_box_list
        poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
    return dewarp(img, poinst)


def dewarp(image, poinst):
    """Warp the quadrilateral ``poinst`` (tl, tr, br, bl order) to an axis-aligned rectangle."""
    if isinstance(poinst, list):
        poinst = np.array(poinst, dtype="float32")
    (tl, tr, br, bl) = poinst
    # Output size is taken from the longest opposing edges of the quad.
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))
    dst = np.array(
        [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
        dtype="float32",
    )
    M = cv2.getPerspectiveTransform(poinst, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped
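
# Illustrative (hypothetical) values: for corners detected at
#   corners = [[120, 80], [860, 95], [875, 560], [110, 540]]  # tl, tr, br, bl
# dewarp(image, corners) returns a 765x465 (width x height) crop, i.e. the
# lengths of the longest opposing edges of that quad.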


class MdetPredictor:
    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        self.model = init_detector(config, checkpoint, device=device)
        self.class_names = self.model.CLASSES

    def infer(self, image, threshold=0.2):
        bbox_result = inference_detector(self.model, image)
        bboxes = np.vstack(bbox_result)
        # bbox_result holds one array per class; build a matching label array.
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        res_bboxes = []
        res_labels = []
        for idx, box in enumerate(bboxes):
            score = box[-1]
            if score >= threshold:
                label = labels[idx]
                res_bboxes.append(box.tolist()[:4])
                res_labels.append(self.class_names[label])
        return res_bboxes, res_labels
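
# A minimal usage sketch (the paths below are placeholders, not the project's
# real config/checkpoint files):
#
#   predictor = MdetPredictor("corner_cfg.py", "corner_weights.pth", device="cpu")
#   boxes, names = predictor.infer(some_bgr_image, threshold=0.3)
#   # boxes -> [[xmin, ymin, xmax, ymax], ...]; names -> matching class names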


class ImageTransformer:
    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        self.corner_detect_model = MdetPredictor(config, checkpoint, device)

    def __call__(self, image, threshold=0.2):
        """
        Args:
            image (np.ndarray): BGR image

        Returns the dewarped card image, or None if no card was detected.
        """
        corner_result = self.corner_detect_model.infer(image, threshold=threshold)
        corners_dict = self.__extract_corners(corner_result)
        card_image = self.__crop_image_based_on_corners(image, corners_dict)
        return card_image

    def __extract_corners(self, corner_result):
        bboxes, labels = corner_result
        # convert bbox coordinates to int and keep the first box per label
        bboxes = [[int(x) for x in box] for box in bboxes]
        output = {k: bboxes[labels.index(k)] for k in labels}
        return output

    def __crop_image_based_on_corners(self, image, corners_dict):
        """
        Args:
            corners_dict (dict): detected label -> [xmin, ymin, xmax, ymax]
        """
        if "card" in corners_dict.keys():
            if len(corners_dict.keys()) == 5:
                # all four corners detected: their centers form the quad
                points = [
                    corners_dict["top_left"],
                    corners_dict["top_right"],
                    corners_dict["bottom_right"],
                    corners_dict["bottom_left"],
                ]
            else:
                # fall back to the whole-card box
                points = corners_dict["card"]
            card_image = crop_align_card(image, points)
        else:
            card_image = None
        return card_image


def crop_location(image_url):
    transform_module = ImageTransformer(
        config="./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py",
        checkpoint="./models/Kie_AHung/best_bbox_mAP_epoch_100.pth",
        device="cuda:0",
    )
    # Download the image and decode it with its channels unchanged.
    req = urllib.request.urlopen(image_url)
    arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)
    card_image = transform_module(img)
    # Fall back to the original image if no card was detected.
    if card_image is not None:
        return card_image
    else:
        return img
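

if __name__ == "__main__":
    # Minimal smoke test, assuming the model files referenced in crop_location
    # exist and a CUDA device is available; the URL below is a placeholder.
    test_url = "https://example.com/sample_id_card.jpg"
    cropped = crop_location(test_url)
    cv2.imwrite("cropped_card.jpg", cropped)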