import os
import cv2 as cv
import glob
from xml.dom.expatbuilder import parseString
from lxml.etree import Element, tostring, SubElement
import tqdm
from common.utils.global_variables import *


def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
    """Write a list of boxes to an annotation XML file (Pascal VOC-like layout).

    The document has an ``annotation`` root containing ``folder``,
    ``filename``, ``size`` (``width``/``height``/``depth``), ``segmented`` and
    one ``object`` node per box with its ``name`` and ``bndbox`` coordinates.

    When ``img_pth`` is given, the image is read with OpenCV, written again
    as ``xml_pth`` with its last three characters replaced by ``jpg``, and
    its dimensions are stored in the ``size`` node.

    Args:
        boxes_lst: Iterable of objects exposing ``xmin``, ``ymin``, ``xmax``,
            ``ymax`` and ``label`` attributes. A ``label`` of ``None`` is
            written as an empty string.
        xml_pth: Destination path of the XML file.
        img_pth (str, optional): Path of the annotated image. Defaults to '',
            in which case no image is copied and width/height are set to 0.

    Raises:
        OSError: If ``img_pth`` is given but OpenCV cannot read the image.
    """
    node_root = Element("annotation")

    node_folder = SubElement(node_root, "folder")
    node_folder.text = "images"

    node_filename = SubElement(node_root, "filename")
    node_filename.text = os.path.basename(img_pth)

    # insert size of image
    if img_pth == "":
        width, height = 0, 0
    else:
        img = cv.imread(img_pth)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {img_pth}")
        # Keep a copy of the image next to the XML file (same stem, .jpg).
        new_path = xml_pth[:-3] + "jpg"
        cv.imwrite(new_path, img)
        # OpenCV arrays are (rows, cols, channels) == (height, width, ...).
        height, width = img.shape[:2]

    node_size = SubElement(node_root, "size")

    node_width = SubElement(node_size, "width")
    node_width.text = str(width)

    node_height = SubElement(node_size, "height")
    node_height.text = str(height)

    node_depth = SubElement(node_size, "depth")
    node_depth.text = "3"

    node_segmented = SubElement(node_root, "segmented")
    node_segmented.text = "0"

    for box in boxes_lst:
        left, top, right, bottom = box.xmin, box.ymin, box.xmax, box.ymax
        left, top, right, bottom = str(left), str(top), str(right), str(bottom)
        label = box.label
        if label is None:
            label = ""

        node_object = SubElement(node_root, "object")
        node_name = SubElement(node_object, "name")
        node_name.text = label

        node_pose = SubElement(node_object, "pose")
        node_pose.text = "Unspecified"
        node_truncated = SubElement(node_object, "truncated")
        node_truncated.text = "0"
        node_difficult = SubElement(node_object, "difficult")
        node_difficult.text = "0"

        # insert bounding box
        node_bndbox = SubElement(node_object, "bndbox")
        node_xmin = SubElement(node_bndbox, "xmin")
        node_xmin.text = left
        node_ymin = SubElement(node_bndbox, "ymin")
        node_ymin.text = top
        node_xmax = SubElement(node_bndbox, "xmax")
        node_xmax.text = right
        node_ymax = SubElement(node_bndbox, "ymax")
        node_ymax.text = bottom

    # Serialize with lxml, then re-parse with minidom to write the final file.
    xml = tostring(node_root, pretty_print=True)
    dom = parseString(xml)
    with open(xml_pth, "w+", encoding="utf-8") as f:
        dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8")


class Box:
    """Mutable record for one labelled rectangle.

    Every instance starts with all four coordinates (``xmax``, ``ymax``,
    ``xmin``, ``ymin``) at 0 and both ``label`` and ``kie_label`` set to the
    empty string; callers fill the attributes in after construction.
    """

    def __init__(self):
        # Coordinates default to the origin.
        self.xmax = self.ymax = self.xmin = self.ymin = 0
        # Text labels default to empty.
        self.label = self.kie_label = ""


def check_iou(box1: "Box", box2: "Box", threshold=0.9):
    """Return True when the IoU of two boxes is strictly above ``threshold``.

    IoU (intersection over union) is the area shared by both boxes divided
    by the area covered by either of them.

    Args:
        box1: Object with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` attributes.
        box2: Object with the same four attributes.
        threshold: Value the IoU must exceed. Defaults to 0.9.

    Returns:
        bool: ``True`` if ``IoU > threshold``, otherwise ``False``. Boxes
        whose union has no area (e.g. two zero-sized boxes) yield ``False``.
    """
    area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin)
    area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)

    # Corners of the overlapping rectangle (if any).
    xmin_intersect = max(box1.xmin, box2.xmin)
    ymin_intersect = max(box1.ymin, box2.ymin)
    xmax_intersect = min(box1.xmax, box2.xmax)
    ymax_intersect = min(box1.ymax, box2.ymax)

    if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
        area_intersect = 0
    else:
        area_intersect = (xmax_intersect - xmin_intersect) * (
            ymax_intersect - ymin_intersect
        )

    union = area1 + area2 - area_intersect
    if union <= 0:
        # Degenerate boxes: avoid dividing by zero.
        return False

    return area_intersect / union > threshold


# Directory searched for the image (<stem>.jpg) matching each label file.
DATA_ROOT = "/home/sds/hoangmd/TokenClassification/images/infer"
# Directory holding the pseudo-label .txt files (one per image).
PSEUDO_LABEL = "/home/sds/hoangmd/TokenClassification/infer/"
list_files = glob.glob(PSEUDO_LABEL + "*.txt")

# Make sure the output directory exists before any file is written into it.
os.makedirs("generated_label/", exist_ok=True)

for file in tqdm.tqdm(list_files):
    stem = os.path.splitext(os.path.basename(file))[0]
    xml_path = os.path.join("generated_label/", stem + ".xml")
    img_path = os.path.join(DATA_ROOT, stem + ".jpg")
    if not os.path.exists(img_path):
        # No matching image: nothing to annotate.
        continue

    boxes = []
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            # Each line: xmin <TAB> ymin <TAB> xmax <TAB> ymax <TAB> label.
            # Only the newline is stripped, so a final line without one
            # keeps its last label character.
            xmin, ymin, xmax, ymax, label = line.rstrip("\n").split("\t", 4)
            box = Box()
            box.xmin = int(float(xmin))  # left , top , right, bottom
            box.ymin = int(float(ymin))
            box.xmax = int(float(xmax))
            box.ymax = int(float(ymax))
            box.label = label
            boxes.append(box)

    # Order boxes top-to-bottom, then left-to-right.
    boxes.sort(key=lambda x: (x.ymin, x.xmin))

    boxes_to_xml(boxes, xml_path, img_path)