import glob
import os
from xml.dom.expatbuilder import parseString

import cv2 as cv
import tqdm
from lxml.etree import Element, SubElement, tostring

from common.utils.global_variables import *


def _add_text_node(parent, tag, text):
    """Append a child element named *tag* holding *text* and return it."""
    node = SubElement(parent, tag)
    node.text = text
    return node


def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
    """Write bounding boxes to a Pascal-VOC-style XML annotation file.

    When ``img_pth`` is given, the image is also re-saved next to the XML
    file (same stem, ``.jpg`` extension) and its size is recorded in the
    ``<size>`` element.

    Args:
        boxes_lst: Iterable of objects exposing ``xmin``, ``ymin``, ``xmax``,
            ``ymax`` and ``label`` attributes (e.g. ``Box``).
        xml_pth (str): Destination path of the XML file.
        img_pth (str, optional): Path of the annotated image. Defaults to
            ``""``, in which case width and height are written as 0 and no
            image is copied.

    Raises:
        ValueError: If ``img_pth`` is given but OpenCV cannot read the image.
    """
    node_root = Element("annotation")
    _add_text_node(node_root, "folder", "images")
    _add_text_node(node_root, "filename", os.path.basename(img_pth))

    # Image size.
    if img_pth == "":
        width, height = 0, 0
    else:
        img = cv.imread(img_pth)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise ValueError("could not read image: {}".format(img_pth))
        cv.imwrite(os.path.splitext(xml_pth)[0] + ".jpg", img)
        # ndarray shape is (rows, cols, channels), i.e. height comes first.
        height, width = img.shape[:2]

    node_size = SubElement(node_root, "size")
    _add_text_node(node_size, "width", str(width))
    _add_text_node(node_size, "height", str(height))
    _add_text_node(node_size, "depth", "3")

    _add_text_node(node_root, "segmented", "0")

    for box in boxes_lst:
        label = box.label
        if label is None:
            label = ""

        node_object = SubElement(node_root, "object")
        _add_text_node(node_object, "name", label)
        _add_text_node(node_object, "pose", "Unspecified")
        _add_text_node(node_object, "truncated", "0")
        _add_text_node(node_object, "difficult", "0")

        # Bounding box: left, top, right, bottom.
        node_bndbox = SubElement(node_object, "bndbox")
        _add_text_node(node_bndbox, "xmin", str(box.xmin))
        _add_text_node(node_bndbox, "ymin", str(box.ymin))
        _add_text_node(node_bndbox, "xmax", str(box.xmax))
        _add_text_node(node_bndbox, "ymax", str(box.ymax))

    # Serialise with lxml, then re-parse and write through the DOM writer
    # (kept exactly as before so the on-disk format does not change).
    xml = tostring(node_root, pretty_print=True)
    dom = parseString(xml)
    with open(xml_pth, "w+", encoding="utf-8") as f:
        dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8")


class Box:
    """Axis-aligned bounding box with a text label and a KIE label.

    ``Box()`` still yields an all-zero box with empty labels; the optional
    arguments only allow setting the fields at construction time.
    """

    def __init__(self, xmin=0, ymin=0, xmax=0, ymax=0, label="", kie_label=""):
        self.xmax = xmax
        self.ymax = ymax
        self.xmin = xmin
        self.ymin = ymin
        self.label = label
        self.kie_label = kie_label

    def __repr__(self):
        return "Box(xmin={!r}, ymin={!r}, xmax={!r}, ymax={!r}, label={!r}, kie_label={!r})".format(
            self.xmin, self.ymin, self.xmax, self.ymax, self.label, self.kie_label
        )


def check_iou(box1: Box, box2: Box, threshold=0.9):
    """Return True when the intersection-over-union of two boxes exceeds *threshold*.

    Args:
        box1: First box (``xmin``, ``ymin``, ``xmax``, ``ymax`` attributes).
        box2: Second box.
        threshold: IoU value that must be strictly exceeded. Defaults to 0.9.

    Returns:
        bool: ``True`` if IoU > ``threshold``; ``False`` otherwise, including
        when the union area is not positive (degenerate boxes).
    """
    area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin)
    area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)

    xmin_intersect = max(box1.xmin, box2.xmin)
    ymin_intersect = max(box1.ymin, box2.ymin)
    xmax_intersect = min(box1.xmax, box2.xmax)
    ymax_intersect = min(box1.ymax, box2.ymax)

    if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
        area_intersect = 0
    else:
        # Width times height of the overlap rectangle.
        area_intersect = (xmax_intersect - xmin_intersect) * (
            ymax_intersect - ymin_intersect
        )

    union = area1 + area2 - area_intersect
    if union <= 0:
        # Zero-area boxes: avoid dividing by zero, treat as "no match".
        return False

    return area_intersect / union > threshold


def _read_boxes(label_path):
    """Parse a pseudo-label text file into a list of ``Box`` objects.

    Each non-blank line must hold five tab-separated fields:
    ``xmin``, ``ymin``, ``xmax``, ``ymax``, ``label``. Coordinates are parsed
    as floats and truncated to ints.
    """
    parsed = []
    with open(label_path, "r", encoding="utf-8") as f:
        for line in f:
            # Drop only the line terminator so a final line without a
            # trailing newline keeps the last character of its label.
            line = line.rstrip("\n")
            if not line.strip():
                continue
            xmin, ymin, xmax, ymax, label = line.split("\t")
            box = Box()
            box.xmin = int(float(xmin))  # left, top, right, bottom
            box.ymin = int(float(ymin))
            box.xmax = int(float(xmax))
            box.ymax = int(float(ymax))
            box.label = label
            parsed.append(box)
    return parsed


DATA_ROOT = "/home/sds/hoangmd/TokenClassification/images/infer"
PSEUDO_LABEL = "/home/sds/hoangmd/TokenClassification/infer/"
list_files = glob.glob(PSEUDO_LABEL + "*.txt")

# The XML files (and image copies) are written here; make sure it exists.
os.makedirs("generated_label/", exist_ok=True)

for label_file in tqdm.tqdm(list_files):
    stem = os.path.splitext(os.path.basename(label_file))[0]
    xml_path = os.path.join("generated_label/", stem + ".xml")
    img_path = os.path.join(DATA_ROOT, stem + ".jpg")
    if not os.path.exists(img_path):
        # No matching image for this label file: nothing to annotate.
        continue

    boxes = _read_boxes(label_file)
    # Order boxes top-to-bottom, then left-to-right.
    boxes.sort(key=lambda b: (b.ymin, b.xmin))
    boxes_to_xml(boxes, xml_path, img_path)