import functools
import urllib
import urllib.request

import cv2
import numpy as np
from mmdet.apis import inference_detector, init_detector

def get_center(box):
    """Return the integer midpoint ``[x, y]`` of an ``(xmin, ymin, xmax, ymax)`` box.

    Each coordinate is the arithmetic mean of the two edges, truncated toward
    zero with ``int`` (so ``-1.5`` becomes ``-1``).
    """
    left, top, right, bottom = box
    return [int((left + right) / 2), int((top + bottom) / 2)]


def cal_euclidean_dist(p1, p2):
    """Return the Euclidean (L2) distance between two points.

    Args:
        p1 (array-like): first point, e.g. ``[x, y]``; list, tuple or ndarray.
        p2 (array-like): second point, same shape as ``p1``.

    Returns:
        numpy.float64: the distance ``||p1 - p2||``.
    """
    # Coerce to float so plain lists/tuples work (list - list raises TypeError)
    # and unsigned integer inputs cannot wrap around during the subtraction.
    a = np.asarray(p1, dtype=float)
    b = np.asarray(p2, dtype=float)
    return np.linalg.norm(a - b)


def bbox_to_four_poinst(bbox):
    """Expand an axis-aligned bounding box into its four corner points.

    Args:
        bbox: sequence ``(xmin, ymin, xmax, ymax)``.

    Returns:
        list[list]: the corners in the order top-left, top-right,
        bottom-right, bottom-left (assuming y grows downward, as in image
        coordinates).
    """
    x0, y0, x1, y1 = bbox
    top_left = [x0, y0]
    top_right = [x1, y0]
    bottom_right = [x1, y1]
    bottom_left = [x0, y1]
    return [top_left, top_right, bottom_right, bottom_left]


def find_closest_point(src_point, point_list):
    """Return the index of the point in ``point_list`` nearest to ``src_point``.

    Args:
        src_point (array-like): query point ``[x, y]``.
        point_list (array-like): candidate points, one ``[x, y]`` per row.

    Returns:
        int: index into ``point_list`` of the closest candidate (Euclidean
        distance); ties resolve to the lowest index.

    Raises:
        ValueError: if ``point_list`` is empty.
    """
    candidates = np.asarray(point_list, dtype=float)
    if candidates.size == 0:
        raise ValueError("point_list must contain at least one point")
    src = np.asarray(src_point, dtype=float)
    # Vectorised distance from src to every row. The previous version wrapped a
    # generator in np.array(), producing a 0-d object array whose argmin was
    # always 0 regardless of the actual distances.
    distances = np.linalg.norm(candidates - src, axis=1)
    return int(np.argmin(distances))


def crop_align_card(img_src, corner_box_list):
    """Crop a card from ``img_src`` and warp it to an upright rectangle.

    Args:
        img_src (np.ndarray): source image; it is not modified.
        corner_box_list: either
            * four corner boxes, each ``(xmin, ymin, xmax, ymax)``, in the order
              top-left, top-right, bottom-right, bottom-left (the order
              ``dewarp`` unpacks); the centre of each box becomes a card
              corner; or
            * a single card box ``(xmin, ymin, xmax, ymax)`` whose four
              corners are used directly.

    Returns:
        np.ndarray: the dewarped crop produced by ``dewarp``.
    """
    # A nested first element (list, tuple or ndarray row) means we were given
    # corner boxes; a scalar means a single flat bbox. Checking ``ndim``
    # instead of ``isinstance(..., list)`` keeps tuples / ndarrays of boxes
    # from being mis-routed into the flat-bbox branch.
    if np.ndim(corner_box_list[0]) > 0:
        points = [get_center(box) for box in corner_box_list]
    else:
        points = bbox_to_four_poinst(corner_box_list)

    # dewarp only reads the image (cv2.warpPerspective returns a new array),
    # so no defensive copy is needed.
    return dewarp(img_src, points)


def dewarp(image, poinst):
    """Warp the quadrilateral ``poinst`` of ``image`` onto an upright rectangle.

    Args:
        image (np.ndarray): source image.
        poinst (array-like): four ``[x, y]`` points ordered top-left,
            top-right, bottom-right, bottom-left. Lists, tuples and ndarrays of
            any numeric dtype are accepted.

    Returns:
        np.ndarray: the warped image, ``max_width`` wide and ``max_height``
        tall, where each dimension is the longer of the quad's two opposite
        edges (truncated to int).
    """
    # cv2.getPerspectiveTransform only accepts float32 point arrays. The old
    # code converted ``list`` inputs only, so tuples or int/float64 ndarrays
    # were passed through unchanged and rejected by OpenCV.
    quad = np.asarray(poinst, dtype=np.float32)
    tl, tr, br, bl = quad

    max_width = max(int(np.linalg.norm(br - bl)), int(np.linalg.norm(tr - tl)))
    max_height = max(int(np.linalg.norm(tr - br)), int(np.linalg.norm(tl - bl)))

    # Destination corners in the same tl, tr, br, bl order as ``quad``.
    dst = np.array(
        [
            [0, 0],
            [max_width - 1, 0],
            [max_width - 1, max_height - 1],
            [0, max_height - 1],
        ],
        dtype=np.float32,
    )

    matrix = cv2.getPerspectiveTransform(quad, dst)
    return cv2.warpPerspective(image, matrix, (max_width, max_height))


class MdetPredictor:
    """Thin wrapper around an MMDetection detector returning thresholded boxes.

    Args:
        config (str): path to the MMDetection config file.
        checkpoint (str): path to the model checkpoint.
        device (str): device string passed to ``init_detector``.
    """

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        self.model = init_detector(config, checkpoint, device=device)
        # Class-index -> class-name lookup used to label detections.
        self.class_names = self.model.CLASSES

    def infer(self, image, threshold=0.2):
        """Run the detector on ``image`` and keep detections scoring >= ``threshold``.

        Args:
            image: input accepted by ``inference_detector``.
            threshold (float): minimum score (last column of each box row).

        Returns:
            tuple[list[list[float]], list]: ``(bboxes, labels)`` where each bbox
            is ``[xmin, ymin, xmax, ymax]`` and ``labels[i]`` is the class name
            of ``bboxes[i]``. Order follows class index, then the detector's
            own per-class order.
        """
        bbox_result = inference_detector(self.model, image)
        return self._filter_detections(bbox_result, self.class_names, threshold)

    @staticmethod
    def _filter_detections(bbox_result, class_names, threshold):
        """Flatten per-class detection arrays and drop low-score boxes.

        Args:
            bbox_result: sequence indexed by class id, each item an array of
                rows ``[xmin, ymin, xmax, ymax, score]``; may also be a
                ``(bbox_result, segm_result)`` tuple as returned by
                mask-head models.
            class_names: sequence mapping class id -> name.
            threshold (float): minimum score to keep.

        Returns:
            tuple[list[list[float]], list]: kept boxes (first four columns)
            and their class names.
        """
        # Models with a mask head return (bbox_result, segm_result); only the
        # boxes are needed here.
        if isinstance(bbox_result, tuple):
            bbox_result = bbox_result[0]
        # np.vstack raises on an empty sequence.
        if len(bbox_result) == 0:
            return [], []

        bboxes = np.vstack(bbox_result)
        labels = np.concatenate(
            [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
        )

        keep = bboxes[:, -1] >= threshold
        res_bboxes = bboxes[keep, :4].tolist()
        res_labels = [class_names[label] for label in labels[keep]]
        return res_bboxes, res_labels


class ImageTransformer:
    """Detect a card (and optionally its four corners) and return an aligned crop.

    Args:
        config (str): MMDetection config for the card/corner detector.
        checkpoint (str): checkpoint path for that detector.
        device (str): device string forwarded to ``MdetPredictor``.
    """

    # Corner label names, in the tl, tr, br, bl order that crop_align_card expects.
    _CORNER_LABELS = ("top_left", "top_right", "bottom_right", "bottom_left")

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        self.corner_detect_model = MdetPredictor(config, checkpoint, device)

    def __call__(self, image, threshold=0.2):
        """Return the aligned card crop of ``image``.

        Args:
            image (np.ndarray): BGR image
            threshold (float): minimum detection score forwarded to the
                detector (previously accepted but silently ignored).

        Returns:
            np.ndarray | None: the dewarped card, or ``None`` when no
            ``"card"`` detection survives the threshold.
        """
        corner_result = self.corner_detect_model.infer(image, threshold=threshold)
        corners_dict = self.__extract_corners(corner_result)
        return self.__crop_image_based_on_corners(image, corners_dict)

    def __extract_corners(self, corner_result):
        """Map each detected label to its first box, with coordinates cast to int.

        Args:
            corner_result (tuple): ``(bboxes, labels)`` as returned by
                ``MdetPredictor.infer``.

        Returns:
            dict: label -> ``[xmin, ymin, xmax, ymax]`` (ints).
        """
        bboxes, labels = corner_result
        output = {}
        # Single pass instead of labels.index() per label (O(n^2)). The first
        # occurrence wins, as before; presumably that is the highest-scoring
        # box if the detector sorts per-class results by score — verify.
        for box, label in zip(bboxes, labels):
            if label not in output:
                output[label] = [int(x) for x in box]
        return output

    def __crop_image_based_on_corners(self, image, corners_dict):
        """Crop the card using corner detections when complete, else the card box.

        Args:
            image (np.ndarray): source image.
            corners_dict (dict): label -> box, as built by ``__extract_corners``.

        Returns:
            np.ndarray | None: the aligned crop, or ``None`` if no ``"card"``
            label was detected.
        """
        if "card" not in corners_dict:
            return None

        # Use the four corner boxes only when every one of them was detected;
        # otherwise fall back to the card bbox. (A bare "5 keys" count breaks
        # if a non-corner label is ever among the detections.)
        if all(name in corners_dict for name in self._CORNER_LABELS):
            points = [corners_dict[name] for name in self._CORNER_LABELS]
        else:
            points = corners_dict["card"]
        return crop_align_card(image, points)


_DEFAULT_CONFIG = "./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py"
_DEFAULT_CHECKPOINT = "./models/Kie_AHung/best_bbox_mAP_epoch_100.pth"


# maxsize=1 keeps at most one detector resident (checkpoints are large and may
# live on the GPU); a different config/checkpoint/device evicts the old one.
@functools.lru_cache(maxsize=1)
def _get_transform_module(config, checkpoint, device):
    """Build the card ImageTransformer once per (config, checkpoint, device)."""
    return ImageTransformer(config=config, checkpoint=checkpoint, device=device)


def crop_location(
    image_url,
    *,
    config=_DEFAULT_CONFIG,
    checkpoint=_DEFAULT_CHECKPOINT,
    device="cuda:0",
    timeout=None,
):
    """Download an image and return its aligned card crop, or the full image.

    Args:
        image_url (str): URL opened with ``urllib.request.urlopen``.
        config (str): detector config path (defaults to the bundled model).
        checkpoint (str): detector checkpoint path.
        device (str): device for the detector.
        timeout (float | None): download timeout in seconds; ``None`` keeps
            urllib's default behaviour.

    Returns:
        np.ndarray: the card crop if a card was detected, otherwise the
        decoded image unchanged.

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    # NOTE(security): urlopen accepts any scheme it has a handler for
    # (including file://). If image_url can come from untrusted callers,
    # validate the scheme/host before calling this function.
    open_kwargs = {} if timeout is None else {"timeout": timeout}
    with urllib.request.urlopen(image_url, **open_kwargs) as response:
        data = response.read()

    arr = np.asarray(bytearray(data), dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_UNCHANGED)
    if img is None:
        # imdecode signals failure by returning None; fail here with a clear
        # message instead of deep inside the detector.
        raise ValueError(f"could not decode image from {image_url!r}")

    # The detector is cached rather than reloaded from disk on every call.
    # NOTE(review): the cached model is shared by all callers; confirm that
    # concurrent inference is safe if this is called from multiple threads.
    transform_module = _get_transform_module(config, checkpoint, device)
    card_image = transform_module(img)
    return img if card_image is None else card_image