import copy
import math
import time
from typing import Callable

import cv2
import numpy as np
import torch

from sdsvkvu.utils.post_processing import (
    get_string,
    get_string_by_deduplicate_bbox,
    get_string_with_word2line,
)


class Timer:
    """Context manager that prints how long the wrapped block took."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, func: Callable, *args):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        print(f"[INFO]: {self.name} took {self.elapsed_time:.6f} seconds")


def get_colormap():
    # BGR colors used when drawing predictions.
    return {
        "others": (0, 0, 255),          # others: red
        "title": (0, 255, 255),         # title: yellow
        "key": (255, 0, 0),             # key: blue
        "value": (0, 255, 0),           # value: green
        "header": (233, 197, 15),       # header
        "group": (0, 128, 128),         # group
        "relation": (0, 0, 255),        # (128, 128, 128),  # relation
        # "others": (187, 125, 250),    # pink
        "seller": (183, 50, 255),       # bold pink
        "date_key": (128, 51, 115),     # orange
        "date_value": (55, 250, 250),   # yellow
        "product_name": (245, 61, 61),  # blue
        "product_code": (233, 197, 17), # header
        "quantity": (102, 255, 102),    # green
        "sn_key": (179, 134, 89),
        "sn_value": (51, 153, 204),
        "invoice_number_key": (40, 90, 144),
        "invoice_number_value": (162, 239, 204),
        "sold_key": (74, 180, 150),
        "sold_value": (14, 184, 53),
        "voucher": (39, 86, 103),
        "website": (207, 19, 85),
        "hotline": (153, 224, 56),
        # "group": (0, 128, 128),       # brown
        # "relation": (0, 0, 255),      # (128, 128, 128),  # red
    }


def convert_image(image):
    # Read the EXIF orientation tag (0x0112) from a PIL image, if present.
    exif = image._getexif()
    orientation = None
    if exif is not None:
        orientation = exif.get(0x0112)
        # Convert the PIL image to OpenCV (BGR) format
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        # Rotate the image in OpenCV if necessary
        if orientation == 3:
            image = cv2.rotate(image, cv2.ROTATE_180)
        elif orientation == 6:
            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        elif orientation == 8:
            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        image = np.asarray(image)

    if len(image.shape) == 2:
        # Grayscale: replicate the single channel to get a 3-channel image.
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    assert len(image.shape) == 3
    return image, orientation


def visualize(image, bbox, pr_class_words, pr_relations, color_map,
              labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
    # image, orientation = convert_image(image)
    # if orientation is not None and orientation == 6:
    #     width, height, _ = image.shape
    # else:
    #     height, width, _ = image.shape

    # Draw one rectangle per word; words of the same group are chained
    # with lines between consecutive box centers.
    if len(pr_class_words) > 0:
        id2label = {k: labels[k] for k in range(len(labels))}
        for lb, groups in enumerate(pr_class_words):
            if lb == 0:  # skip the 'others' class
                continue
            for group_id, group in enumerate(groups):
                for i, word_id in enumerate(group):
                    # x0, y0, x1, y1 = revert_scale_bbox(bbox[word_id], width, height)
                    x0, y0, x1, y1 = bbox[word_id]
                    cv2.rectangle(image, (x0, y0), (x1, y1),
                                  color=color_map[id2label[lb]], thickness=thickness)
                    if i == 0:
                        x_center0, y_center0 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                    else:
                        x_center1, y_center1 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                        cv2.line(image, (x_center0, y_center0), (x_center1, y_center1),
                                 color=color_map['group'], thickness=thickness)
                        x_center0, y_center0 = x_center1, y_center1

    # Draw a line between the box centers of each linked pair.
    if len(pr_relations) > 0:
        for pair in pr_relations:
            # xyxy0 = revert_scale_bbox(bbox[pair[0]], width, height)
            # xyxy1 = revert_scale_bbox(bbox[pair[1]], width, height)
            xyxy0 = bbox[pair[0]]
            xyxy1 = bbox[pair[1]]
            x_center0, y_center0 = int((xyxy0[0] + xyxy0[2]) / 2), int((xyxy0[1] + xyxy0[3]) / 2)
            x_center1, y_center1 = int((xyxy1[0] + xyxy1[2]) / 2), int((xyxy1[1] + xyxy1[3]) / 2)
            cv2.line(image, (x_center0, y_center0), (x_center1, y_center1),
                     color=color_map['relation'], thickness=thickness)
    return image


def revert_scale_bbox(box, width, height):
    # Map a box from the 0-1000 normalized space back to pixel coordinates.
    return [
        int((box[0] / 1000) * width),
        int((box[1] / 1000) * height),
        int((box[2] / 1000) * width),
        int((box[3] / 1000) * height),
    ]
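# Example sketch (hypothetical values, not part of the pipeline): Timer is a
# plain context manager, and revert_scale_bbox maps a LayoutLM-style box from
# the 0-1000 normalized space back to pixels on a 640x480 page:
#
#     with Timer("revert_scale_bbox"):
#         print(revert_scale_bbox([100, 250, 500, 300], width=640, height=480))
#     # -> [64, 120, 320, 144]
#     # -> "[INFO]: revert_scale_bbox took 0.0000xx seconds"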
def draw_kvu_outputs(image: np.ndarray, bbox: list, pr_class_words: list, pr_relations: list,
                     class_names: list = ['others', 'title', 'key', 'value', 'header'],
                     thickness: int = 1):
    color_map = get_colormap()
    image = visualize(image, bbox, pr_class_words, pr_relations, color_map, class_names, thickness)
    if image.shape[2] == 2:
        # 16-bit BGR565 input: expand to 3-channel BGR before the final conversion.
        image = cv2.cvtColor(image, cv2.COLOR_BGR5652BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    # Split `elements` into windows of `window_size` items starting every
    # `slice_interval` items; consecutive windows overlap when
    # slice_interval < window_size.
    element_windows = []
    if len(elements) > window_size:
        max_step = math.ceil((len(elements) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            if (i * slice_interval + window_size) >= len(elements):
                _window = copy.deepcopy(elements[i * slice_interval:])
            else:
                _window = copy.deepcopy(elements[i * slice_interval: i * slice_interval + window_size])
            element_windows.append(_window)
        return element_windows
    else:
        return [elements]


def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    # Stitch per-window token embeddings back into one sequence. Overlapping
    # positions are either averaged with the previous window or taken from
    # the newer window.
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]
        overlap_gap = copy.deepcopy(loverlaps[i - 1])
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])
        if overlap_gap != 0:
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, (
                f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
            )
            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
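# Example sketch (hypothetical numbers): how windows overlap. With 10 tokens,
# window_size=4 and slice_interval=3, consecutive windows here share one
# element; that shared span is the overlap gap that merged_token_embeddings
# averages (average=True) or takes from the newer window (average=False):
#
#     sliding_windows(list(range(10)), window_size=4, slice_interval=3)
#     # -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]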
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]],
                    dim=1,
                )
            else:
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]],
                    dim=1,
                )
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)


def parse_initial_words(itc_label, box_first_token_mask, class_names):
    # Collect, per class, the indices of tokens predicted as the first token
    # of a word group (initial-token classification).
    itc_label_np = itc_label.cpu().numpy()
    box_first_token_mask_np = box_first_token_mask.cpu().numpy()

    outputs = [[] for _ in range(len(class_names))]
    for token_idx, label in enumerate(itc_label_np):
        if box_first_token_mask_np[token_idx] and label != 0:
            outputs[label].append(token_idx)
    return outputs


def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    # Follow subsequent-token links from each initial token to rebuild full
    # word groups; chains are capped at `max_connections` hops.
    max_connections = 50

    valid_stc_label = stc_label * attention_mask.bool()
    valid_stc_label = valid_stc_label.cpu().numpy()
    stc_label_np = stc_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_stc_label != dummy_idx) * (valid_stc_label != 0)
    )

    next_token_idx_dict = {}
    for token_idx in valid_token_indices[0]:
        next_token_idx_dict[stc_label_np[token_idx]] = token_idx

    outputs = []
    for init_token_indices in init_words:
        sub_outputs = []
        for init_token_idx in init_token_indices:
            cur_token_indices = [init_token_idx]
            for _ in range(max_connections):
                if cur_token_indices[-1] in next_token_idx_dict:
                    if (
                        next_token_idx_dict[cur_token_indices[-1]]
                        not in init_token_indices
                    ):
                        cur_token_indices.append(
                            next_token_idx_dict[cur_token_indices[-1]]
                        )
                    else:
                        break
                else:
                    break
            sub_outputs.append(tuple(cur_token_indices))
        outputs.append(sub_outputs)
    return outputs


def parse_relations(el_label, box_first_token_mask, dummy_idx):
    # Extract (from_token, to_token) links from the entity-linking head.
    valid_el_labels = el_label * box_first_token_mask
    valid_el_labels = valid_el_labels.cpu().numpy()
    el_label_np = el_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_el_labels != dummy_idx) * (valid_el_labels != 0)
    )
    link_map_tuples = []
    for token_idx in valid_token_indices[0]:
        link_map_tuples.append((el_label_np[token_idx], token_idx))
    return set(link_map_tuples)


def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
    # From a list of linked word-group pairs, keep those whose classes match
    # (rel_from, rel_to) and index them by the rel_to group id:
    # {rel_to_group_id: [rel_from_group_id, rel_to_group_id]}
    outputs = {}
    for pair in json:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] in (rel_from, rel_to):
                is_rel[element['class']]['status'] = 1
                is_rel[element['class']]['value'] = element
        if all([v['status'] == 1 for _, v in is_rel.items()]):
            outputs[is_rel[rel_to]['value']['group_id']] = [
                is_rel[rel_from]['value']['group_id'],
                is_rel[rel_to]['value']['group_id'],
            ]
    return outputs


def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
    # {key_group_id: [value_group_id, ...]} restricted to keys that have a header.
    list_keys = list(header_key_pairs.keys())
    relations = {k: [] for k in list_keys}
    for pair in json:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] == rel_from and element['group_id'] in list_keys:
                is_rel[rel_from]['status'] = 1
                is_rel[rel_from]['value'] = element
            if element['class'] == rel_to:
                is_rel[rel_to]['status'] = 1
                is_rel[rel_to]['value'] = element
        if all([v['status'] == 1 for _, v in is_rel.items()]):
            relations[is_rel[rel_from]['value']['group_id']].append(
                is_rel[rel_to]['value']['group_id']
            )
    return relations


def get_key2values_relations(key_value_pairs: dict):
    # Invert {value_group_id: [key_group_id, value_group_id]} into
    # {key_group_id: [value_group_id, ...]}.
    triple_linkings = {}
    for value_group_id, key_value_pair in key_value_pairs.items():
        key_group_id = key_value_pair[0]
        if key_group_id not in triple_linkings:
            triple_linkings[key_group_id] = []
        triple_linkings[key_group_id].append(value_group_id)
    return triple_linkings
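# Example sketch (hypothetical group ids): get_pairs indexes matched links by
# the rel_to group id, and get_key2values_relations regroups them by key, so
# one key can collect several values:
#
#     pairs = [
#         [{'class': 'key', 'group_id': 1}, {'class': 'value', 'group_id': 2}],
#         [{'class': 'key', 'group_id': 1}, {'class': 'value', 'group_id': 5}],
#     ]
#     get_pairs(pairs, rel_from='key', rel_to='value')
#     # -> {2: [1, 2], 5: [1, 5]}
#     get_key2values_relations({2: [1, 2], 5: [1, 5]})
#     # -> {1: [2, 5]}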
def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list:
    # Union of the token boxes belonging to one word group.
    points = [lbbox[i] for i in lword_ids]
    x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1]
    x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3]
    return [x_min, y_min, x_max, y_max]


def merged_token_to_wordgroup(class_words: list, lwords: list, lbbox: list, labels: list) -> dict:
    # Merge token indices into word groups keyed by their first token index.
    word_groups = {}
    id2class = {i: labels[i] for i in range(len(labels))}
    for class_id, lwgroups_in_class in enumerate(class_words):
        for ltokens_in_wgroup in lwgroups_in_class:
            group_id = ltokens_in_wgroup[0]
            ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
            ltokens_to_lbboxes = [lbbox[token] for token in ltokens_in_wgroup]
            # text_string = get_string(ltokens_to_ltexts)
            text_string = get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes)
            # text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes)
            group_bbox = get_wordgroup_bbox(lbbox, ltokens_in_wgroup)
            word_groups[group_id] = {
                'group_id': group_id,
                'text': text_string,
                'class': id2class[class_id],
                'tokens': ltokens_in_wgroup,
                'bbox': group_bbox,
            }
    return word_groups


def verify_linking_id(word_groups: dict, linking_id: int) -> int:
    # If `linking_id` is not a group id, map it to the group that contains it.
    if linking_id not in word_groups:
        for wg_id, _word_group in word_groups.items():
            if linking_id in _word_group['tokens']:
                return wg_id
    return linking_id


def matched_wordgroup_relations(word_groups: dict, lrelations: list) -> list:
    outputs = []
    for pair in lrelations:
        wg_from = verify_linking_id(word_groups, pair[0])
        wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except KeyError:
            print('Not valid pair:', wg_from, wg_to)
    return outputs


def get_single_entity(word_groups: dict, lrelations: list, labels: list) -> dict:
    # Word groups that never appear in any relation, grouped by class.
    # single_entity = {'title': [], 'key': [], 'value': [], 'header': []}
    single_entity = {lb: [] for lb in labels}
    list_linked_ids = []
    for pair in lrelations:
        list_linked_ids.extend(pair)
    for word_group_id, word_group in word_groups.items():
        if word_group_id not in list_linked_ids:
            single_entity[word_group['class']].append(word_group)
    return single_entity
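# Example sketch (hypothetical boxes): get_wordgroup_bbox returns the union
# of the selected token boxes, here tokens 0 and 1:
#
#     lbbox = [[10, 20, 50, 40], [48, 18, 90, 38], [0, 0, 5, 5]]
#     get_wordgroup_bbox(lbbox, [0, 1])
#     # -> [10, 18, 90, 40]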
def export_kvu_outputs(lwords, lbbox, class_words, lrelations,
                       labels=['others', 'title', 'key', 'value', 'header']):
    word_groups = merged_token_to_wordgroup(class_words, lwords, lbbox, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key')      # => {key_group_id: [header_group_id, key_group_id]}
    header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value')  # => {value_group_id: [header_group_id, value_group_id]}
    key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value')        # => {value_group_id: [key_group_id, value_group_id]}
    single_entity = get_single_entity(word_groups, lrelations, labels=labels)

    # table_relations = get_table_relations(linking_pairs, header_key)  # => {key_group_id: [value_group_id1, value_group_id2, ...]}
    key2values_relations = get_key2values_relations(key_value)  # => {key_group_id: [value_group_id1, value_group_id2, ...]}
    triplet_pairs = []
    single_pairs = []
    table = []
    # print('key2values_relations', key2values_relations)
    for key_group_id, list_value_group_ids in key2values_relations.items():
        if len(list_value_group_ids) == 0:
            continue
        elif len(list_value_group_ids) == 1:
            value_group_id = list_value_group_ids[0]
            single_pairs.append({word_groups[key_group_id]['text']: {
                'id': value_group_id,
                'class': "value",
                'text': word_groups[value_group_id]['text'],
                'bbox': word_groups[value_group_id]['bbox'],
                "key_bbox": word_groups[key_group_id]["bbox"],
            }})
        else:
            item = []
            for value_group_id in list_value_group_ids:
                if value_group_id not in header_value.keys():
                    header_group_id = -1  # no header linked to this value
                    header_name_for_value = "non-header"
                else:
                    header_group_id = header_value[value_group_id][0]
                    header_name_for_value = word_groups[header_group_id]['text']
                item.append({
                    'id': value_group_id,
                    'class': 'value',
                    'header': header_name_for_value,
                    'text': word_groups[value_group_id]['text'],
                    'bbox': word_groups[value_group_id]['bbox'],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "header_bbox": word_groups[header_group_id]["bbox"] if header_group_id != -1 else [0, 0, 0, 0],
                })
            if key_group_id not in header_key:
                triplet_pairs.append({
                    word_groups[key_group_id]['text']: item
                })
            else:
                header_group_id = header_key[key_group_id][0]
                header_name_for_key = word_groups[header_group_id]['text']
                item.append({
                    'id': key_group_id,
                    'class': 'key',
                    'header': header_name_for_key,
                    'text': word_groups[key_group_id]['text'],
                    'bbox': word_groups[key_group_id]['bbox'],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "header_bbox": word_groups[header_group_id]["bbox"],
                })
                table.append({key_group_id: item})

    # Add entities without linking
    single_entity_dict = {}
    for class_name, single_items in single_entity.items():
        single_entity_dict[class_name] = []
        for single_item in single_items:
            single_entity_dict[class_name].append({
                'text': single_item['text'],
                'id': single_item['group_id'],
                'class': class_name,
                'bbox': single_item['bbox'],
            })

    if len(table) > 0:
        table = sorted(table, key=lambda x: list(x.keys())[0])
        table = [v for item in table for k, v in item.items()]

    outputs = {}
    outputs['title'] = sorted(single_entity_dict["title"], key=lambda x: x["id"])
    outputs['key'] = sorted(single_entity_dict["key"], key=lambda x: x["id"])
    outputs['value'] = sorted(single_entity_dict["value"], key=lambda x: x["id"])
    outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
    outputs['triplet'] = triplet_pairs
    outputs['table'] = table
    return outputs
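# Example sketch of the structure export_kvu_outputs returns (all texts, ids,
# and boxes below are hypothetical): 'single' holds one-key/one-value pairs,
# 'triplet' holds multi-value keys without a header, and 'table' holds
# multi-value keys whose key is linked to a header:
#
#     {
#         'title':   [{'text': 'INVOICE', 'id': 0, 'class': 'title', 'bbox': [5, 5, 80, 20]}],
#         'key':     [...],   # unlinked keys
#         'value':   [...],   # unlinked values
#         'single':  [{'Date': {'id': 7, 'class': 'value', 'text': '2023-01-01',
#                               'bbox': [90, 5, 160, 20], 'key_bbox': [60, 5, 88, 20]}}],
#         'triplet': [{'Total': [{'id': 9, 'class': 'value', 'header': 'non-header', ...}]}],
#         'table':   [[{'id': 12, 'class': 'value', 'header': 'Qty', ...}, ...]],
#     }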
def export_sbt_outputs(lwords, lbboxes, class_words, lrelations, labels):
    word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    date_key_value_pairs = get_pairs(
        linking_pairs, rel_from="date_key", rel_to="date_value"
    )  # => {date_value_group_id: [date_key_group_id, date_value_group_id]}
    # product_name_code_pairs = get_pairs(
    #     linking_pairs, rel_to="product_name", rel_from="product_code"
    # )  # => {product_name_group_id: [product_code_group_id, product_name_group_id]}
    # product_name_quantity_pairs = get_pairs(
    #     linking_pairs, rel_to="product_name", rel_from="quantity"
    # )  # => {product_name_group_id: [quantity_group_id, product_name_group_id]}
    serial_key_value_pairs = get_pairs(
        linking_pairs, rel_from="sn_key", rel_to="sn_value"
    )  # => {sn_value_group_id: [sn_key_group_id, sn_value_group_id]}
    sold_key_value_pairs = get_pairs(
        linking_pairs, rel_from="sold_key", rel_to="sold_value"
    )  # => {sold_value_group_id: [sold_key_group_id, sold_value_group_id]}
    single_entity = get_single_entity(word_groups, lrelations, labels=labels)

    date_value = []
    sold_value = []
    serial_imei = []
    table = []

    date_relations = get_key2values_relations(date_key_value_pairs)
    for key_group_id, list_value_group_id in date_relations.items():
        for value_group_id in list_value_group_id:
            date_value.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "date_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    sold_relations = get_key2values_relations(sold_key_value_pairs)
    for key_group_id, list_value_group_id in sold_relations.items():
        for value_group_id in list_value_group_id:
            sold_value.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "sold_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    serial_relations = get_key2values_relations(serial_key_value_pairs)
    for key_group_id, list_value_group_id in serial_relations.items():
        for value_group_id in list_value_group_id:
            serial_imei.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "sn_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    single_entity_dict = {}
    for class_name, single_items in single_entity.items():
        single_entity_dict[class_name] = []
        for single_item in single_items:
            single_entity_dict[class_name].append(
                {
                    "text": single_item["text"],
                    "id": single_item["group_id"],
                    "class": class_name,
                    "bbox": single_item["bbox"],
                }
            )

    # list_product_name_group_ids = set(
    #     list(product_name_code_pairs.keys())
    #     + list(product_name_quantity_pairs.keys())
    #     + [x["id"] for x in single_entity_dict["product_name"]]
    # )
    # for product_name_group_id in list_product_name_group_ids:
    #     item = {"productname": [], "modelnumber": [], "qty": []}
    #     item["productname"].append(
    #         {
    #             "text": word_groups[product_name_group_id]["text"],
    #             "id": product_name_group_id,
    #             "class": "product_name",
    #             "bbox": word_groups[product_name_group_id]["bbox"],
    #         }
    #     )
    #     if product_name_group_id in product_name_code_pairs:
    #         product_code_group_id = product_name_code_pairs[product_name_group_id][0]
    #         item["modelnumber"].append(
    #             {
    #                 "text": word_groups[product_code_group_id]["text"],
    #                 "id": product_code_group_id,
    #                 "class": "product_code",
    #                 "bbox": word_groups[product_code_group_id]["bbox"],
    #             }
    #         )
    #     if product_name_group_id in product_name_quantity_pairs:
    #         quantity_group_id = product_name_quantity_pairs[product_name_group_id][0]
    #         item["qty"].append(
    #             {
    #                 "text": word_groups[quantity_group_id]["text"],
    #                 "id": quantity_group_id,
    #                 "class": "quantity",
    #                 "bbox": word_groups[quantity_group_id]["bbox"],
    #             }
    #         )
    #     table.append(item)

    # if len(table) > 0:
    #     table = sorted(table, key=lambda x: x["productname"][0]["id"])

    if len(serial_imei) > 0:
        serial_imei = sorted(serial_imei, key=lambda x: x["id"])

    outputs = {}
    outputs["seller"] = single_entity_dict["seller"]
    outputs["voucher"] = single_entity_dict["voucher"]
    outputs["website"] = single_entity_dict["website"]
    outputs["hotline"] = single_entity_dict["hotline"]
    outputs["sold_value"] = sold_value + single_entity_dict["sold_key"] + single_entity_dict["sold_value"]
    outputs["date_value"] = date_value + single_entity_dict["date_value"] + single_entity_dict["date_key"]
    outputs["serial_imei"] = serial_imei + single_entity_dict["sn_value"] + single_entity_dict["sn_key"]
    # outputs["table"] = table
    return outputs
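# Example sketch (hypothetical label set and inputs): end-to-end use of the
# SBT exporter. lwords/lbboxes come from OCR; class_words and lrelations come
# from the decoding helpers above (parse_initial_words, parse_subsequent_words,
# parse_relations):
#
#     sbt_labels = ['others', 'seller', 'date_key', 'date_value', 'sn_key',
#                   'sn_value', 'sold_key', 'sold_value', 'voucher', 'website',
#                   'hotline']
#     outputs = export_sbt_outputs(lwords, lbboxes, class_words, lrelations, sbt_labels)
#     # outputs keys: 'seller', 'voucher', 'website', 'hotline',
#     #               'sold_value', 'date_value', 'serial_imei'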