"""Pre-processing of OCR results and page images into LayoutXLM inputs.

``KVUProcess`` turns one image (plus its OCR result) into token ids, integer
bounding boxes, attention masks and "first token of a box" flags, either for
the whole page or for sliding windows of words. ``DocumentKVUProcess`` splits
a page into a fixed number of consecutive, non-overlapping word windows and
also returns their concatenation.
"""
import argparse
import itertools
import os
from typing import Any

import imagesize
import numpy as np
import pandas as pd
import torch
from PIL import Image

from lightning_modules.utils import sliding_windows
from utils.run_ocr import load_ocr_engine, process_img
from utils.utils import post_process_basic_ocr, read_json, read_ocr_result_from_txt


class KVUProcess:
    """Convert an image + OCR result into LayoutXLM-style input tensors.

    ``mode`` selects what ``__call__`` returns:

    * ``0`` -- a single dict for the whole page, truncated/padded to
      ``max_seq_length`` tokens.
    * ``1`` -- ``{'windows': [...]}`` with one dict per sliding window of
      words (windows produced by ``sliding_windows`` with ``window_size`` and
      ``slice_interval``).
    * ``2`` -- like ``1`` plus a whole-page dict encoded with up to
      ``DOCUMENT_MAX_SEQ_LENGTH`` token slots, stored under ``'documents'``
      and, for backward compatibility, under the legacy key ``'doduments'``.
    """

    # Token budget for the whole-page encoding produced in mode 2.
    DOCUMENT_MAX_SEQ_LENGTH = 2048

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names,
                 slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor
        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names
        self.slice_interval = slice_interval
        self.window_size = window_size
        self.run_ocr = run_ocr
        self.mode = mode

        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._unk_token)

        # Class name -> label index written into itc_labels by parser_entity_extraction.
        self.class_idx_dic = {class_name: idx for idx, class_name in enumerate(self.class_names)}

        self.ocr_engine = None
        if self.run_ocr == 1:
            self.ocr_engine = load_ocr_engine()

    # ------------------------------------------------------------------ helpers

    def _read_ocr(self, img_path, ocr_path, need_ocr):
        """Return ``(bboxes, words)`` for ``img_path``.

        When ``need_ocr`` is true, OCR is run on the image and written to
        ``tmp.txt`` (in the current directory) before being read back;
        otherwise ``ocr_path`` is read directly.
        """
        if need_ocr:
            # Load lazily: the engine is only created in __init__ when run_ocr == 1,
            # but OCR may also be needed because the cached result file is missing.
            if self.ocr_engine is None:
                self.ocr_engine = load_ocr_engine()
            ocr_path = "tmp.txt"
            process_img(img_path, ocr_path, self.ocr_engine, export_img=False)
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        return lbboxes, post_process_basic_ocr(lwords)

    def _load_image(self, img_path):
        """Return ``(image_features, width, height)`` for ``img_path``.

        ``image_features`` is the first ``pixel_values`` entry produced by the
        feature extractor, converted to a torch tensor.
        """
        width, height = imagesize.get(img_path)
        # Close the file handle; convert() returns an independent, loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if width <= 0 or height <= 0:
            # imagesize reports -1 for formats it cannot parse; fall back to PIL.
            width, height = image.size
        pixel_values = self.feature_extractor([image])['pixel_values'][0]
        return torch.from_numpy(pixel_values.copy()), width, height

    def _build_word_objects(self, bounding_boxes, words):
        """Tokenize each word and expand its ``[x1, y1, x2, y2]`` box to 4 corner points."""
        tokenizer = self.tokenizer_layoutxlm
        word_objects = []
        for bb, text in zip(bounding_boxes, words):
            word_objects.append({
                "layoutxlm_tokens": tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)),
                # Corners in order: top-left, top-right, bottom-right, bottom-left.
                "boundingBox": [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]],
                "text": text,
            })
        return word_objects

    def _preprocess_windows(self, lbboxes, lwords, feature_maps):
        """Encode every sliding window of words with ``self.max_seq_length`` slots."""
        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), (
            f"Shape of lbboxes and lwords after sliding window is not the same "
            f"{len(bbox_windows)} # {len(word_windows)}"
        )
        return [
            self.preprocess(window_bboxes, window_words, feature_maps, max_seq_length=self.max_seq_length)
            for window_bboxes, window_words in zip(bbox_windows, word_windows)
        ]

    # --------------------------------------------------------------- public API

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Build model inputs for one image; the output layout depends on ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` is not 0, 1 or 2.
        """
        if self.mode not in (0, 1, 2):
            raise ValueError(f"Not supported mode: {self.mode}")

        # OCR is (re)run when forced by run_ocr == 1 or when no OCR result file exists.
        need_ocr = (self.run_ocr == 1) or (not os.path.exists(ocr_path))
        lbboxes, lwords = self._read_ocr(img_path, ocr_path, need_ocr)
        image_features, width, height = self._load_image(img_path)
        feature_maps = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}

        if self.mode == 0:
            return self.preprocess(lbboxes, lwords, feature_maps, max_seq_length=self.max_seq_length)

        output = {}
        if self.mode == 2:
            document = self.preprocess(lbboxes, lwords, feature_maps,
                                       max_seq_length=self.DOCUMENT_MAX_SEQ_LENGTH)
            output['documents'] = document
            # Legacy (misspelled) key kept so existing consumers keep working.
            output['doduments'] = document
        output['windows'] = self._preprocess_windows(lbboxes, lwords, feature_maps)
        return output

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Encode ``words``/``bounding_boxes`` into a dict of tensors.

        Args:
            bounding_boxes: per-word boxes indexed as ``[x1, y1, x2, y2]``.
            words: per-word strings, parallel to ``bounding_boxes``.
            feature_maps: dict with ``'image'``, ``'width'``, ``'height'``, ``'img_path'``.
            max_seq_length: number of token slots, including [CLS] and [SEP].
                When it equals ``self.max_seq_length`` the outputs keep that fixed
                length; otherwise they are trimmed to the real token count + 2.

        Raises:
            AssertionError: if the words need more than ``max_seq_length - 2`` tokens.
        """
        list_word_objects = self._build_word_objects(bounding_boxes, words)
        (
            bbox, input_ids, attention_mask, are_box_first_tokens,
            _box_to_token_indices, _box2token_span_map, lwords,
            len_valid_tokens, len_non_overlap_tokens, len_list_tokens,
        ) = self.parser_words(list_word_objects, max_seq_length, feature_maps["width"], feature_maps["height"])

        # parser_words counts every token, so a mismatch means the input was truncated.
        assert len_list_tokens == len_valid_tokens + 2, (
            f"{len_valid_tokens} tokens do not fit into max_seq_length={max_seq_length} "
            f"(2 slots are reserved for [CLS] and [SEP])"
        )
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens

        ntokens = max_seq_length if max_seq_length == self.max_seq_length else len_valid_tokens + 2

        return {
            "img_path": feature_maps['img_path'],
            "words": lwords[:ntokens],
            "len_overlap_tokens": torch.tensor(len_overlap_tokens),
            'len_valid_tokens': torch.tensor(len_valid_tokens),
            "image": feature_maps['image'],
            "input_ids_layoutxlm": torch.from_numpy(input_ids[:ntokens]),
            "attention_mask_layoutxlm": torch.from_numpy(attention_mask[:ntokens]),
            "are_box_first_tokens": torch.from_numpy(are_box_first_tokens[:ntokens]),
            "bbox": torch.from_numpy(bbox[:ntokens]),
        }

    def parser_words(self, words, max_seq_length, width, height):
        """Lay out tokenized words into fixed-size arrays framed by [CLS] ... [SEP].

        Words are appended in order until the next word would overflow
        ``max_seq_length - 2`` slots. Boxes are clipped to the image, normalized
        by ``width``/``height``, reduced to ``(x1, y1, x3, y3)`` (top-left and
        bottom-right corners) and scaled to integers in 0..1000.

        Note: the ``"layoutxlm_tokens"`` and ``"boundingBox"`` entries of
        ``words`` are modified in place (unk substitution, clipping).

        Returns:
            ``(bbox, input_ids, attention_mask, are_box_first_tokens,
            box_to_token_indices, box2token_span_map, lwords, len_valid_tokens,
            len_non_overlap_tokens, len_list_tokens)``. ``len_valid_tokens``
            counts tokens of every word visited, including the word that
            triggered truncation; ``len_non_overlap_tokens`` counts only the
            first ``self.slice_interval`` words.

        Raises:
            ValueError: if ``self.backbone_type`` is not "layoutlm"/"layoutxlm".
        """
        if self.backbone_type not in ("layoutlm", "layoutxlm"):
            raise ValueError(f"Not supported backbone_type: {self.backbone_type}")

        list_bbs = []
        list_words = []
        list_tokens = []
        box2token_span_map = []
        box_to_token_indices = []
        lwords = [''] * max_seq_length

        cum_token_idx = 0
        len_valid_tokens = 0
        len_non_overlap_tokens = 0

        input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        attention_mask = np.zeros(max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

        for word_idx, word in enumerate(words):
            tokens = word["layoutxlm_tokens"]
            bb = word["boundingBox"]
            text = word["text"]

            # Substitute [UNK] BEFORE counting, so the counts match the tokens placed.
            if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)

            len_valid_tokens += len(tokens)
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(tokens)

            if len(list_tokens) + len(tokens) > max_seq_length - 2:
                break

            # Token span of this box in the final sequence (+1 for the leading [CLS]).
            box2token_span_map.append([len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1])
            list_tokens += tokens

            # Clip every corner to the image area.
            for corner in bb:
                corner[0] = max(0.0, min(corner[0], width))
                corner[1] = max(0.0, min(corner[1], height))
            flat_bb = list(itertools.chain(*bb))

            list_bbs.extend([flat_bb] * len(tokens))
            list_words.extend([text] * len(tokens))
            box_to_token_indices.append(list(range(cum_token_idx + 1, cum_token_idx + len(tokens) + 1)))
            cum_token_idx += len(tokens)

        cls_bbs = [0.0] * 8
        sep_bbs = [width, height] * 4

        list_tokens = [self.cls_token_id_layoutxlm] + list_tokens[: max_seq_length - 2] + [self.sep_token_id_layoutxlm]
        list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
        list_words = ([self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2]
                      + [self.tokenizer_layoutxlm._sep_token])

        len_list_tokens = len(list_tokens)
        input_ids[:len_list_tokens] = list_tokens
        attention_mask[:len_list_tokens] = 1
        bbox[:len_list_tokens, :] = list_bbs
        lwords[:len_list_tokens] = list_words

        # Normalize bbox -> 0 ~ 1, then keep top-left/bottom-right corners scaled to 0 ~ 1000.
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
        bbox = (bbox[:, [0, 1, 4, 5]] * 1000).astype(int)

        st_indices = [indices[0] for indices in box_to_token_indices if indices[0] < max_seq_length]
        are_box_first_tokens[st_indices] = True

        return (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens,
        )

    def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length):
        """Build per-token entity labels.

        ``itc_labels[t]`` is the class index of the first token of each entity
        (0 elsewhere). ``stc_labels[t]`` points from each later token of an
        entity to the previous token of that entity (``max_seq_length`` when unset).

        Args:
            parse_class: mapping class name -> list of entities, each a list of
                word indices (indices into ``box_to_token_indices``).
        """
        itc_labels = np.zeros(max_seq_length, dtype=int)
        stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length

        for class_name in self.class_names:
            if class_name == "others" or class_name not in parse_class:
                continue
            for word_list in parse_class[class_name]:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    if word_idx >= len(box_to_token_indices):
                        break
                    for converted_word_idx in box_to_token_indices[word_idx]:
                        if converted_word_idx >= max_seq_length:
                            break  # out of idx
                        if is_first:
                            itc_labels[converted_word_idx] = self.class_idx_dic[class_name]
                            is_first, last_word_idx = False, converted_word_idx
                        else:
                            stc_labels[converted_word_idx] = last_word_idx
                            last_word_idx = converted_word_idx

        return itc_labels, stc_labels

    def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length):
        """Build entity-linking labels from ``(from_box, to_box, ...)`` relations.

        Links are stored at the first token of the target box and point to the
        first token of the source box; ``max_seq_length`` means "no link".
        ``el_labels_from_key`` holds itc 2 -> 3 links (key-value per the original
        comments); ``el_labels`` holds itc 4 -> 2/3 links (header-key / header-value).
        """
        el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length
        el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length
        n_spans = len(box2token_span_map)

        for relation in parse_relation:
            src_box, dst_box = relation[0], relation[1]
            if src_box >= n_spans or dst_box >= n_spans:
                continue
            word_from = box2token_span_map[src_box][0]
            word_to = box2token_span_map[dst_box][0]
            if word_from >= max_seq_length or word_to >= max_seq_length:
                continue
            # Skip targets that already carry both kinds of link; the "unset"
            # sentinel is max_seq_length (previously hard-coded as 512).
            if el_labels[word_to] != max_seq_length and el_labels_from_key[word_to] != max_seq_length:
                continue

            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key-value)
            if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

        return el_labels, el_labels_from_key


class DocumentKVUProcess(KVUProcess):
    """Split a page into ``max_window_count`` consecutive word windows.

    Window ``i`` covers words ``[i * window_size, (i + 1) * window_size)``;
    words beyond ``max_window_count * window_size`` are not encoded.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names,
                 max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names,
                         slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Return ``{'windows': [...], 'documents': {...}}`` for one image."""
        # Unlike the base class, OCR runs only when enabled AND no OCR file exists yet.
        need_ocr = (self.run_ocr == 1) and (not os.path.exists(ocr_path))
        lbboxes, lwords = self._read_ocr(img_path, ocr_path, need_ocr)
        image_features, width, height = self._load_image(img_path)
        feature_maps = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}
        return self.preprocess(lbboxes, lwords, feature_maps, self.max_seq_length)

    def _empty_window(self, image, max_seq_length):
        """Return an all-padding window: no words, fully masked attention.

        ``bbox`` is ``(max_seq_length, 4)`` ints, matching the layout that
        ``parser_words`` produces for real windows.
        """
        return {
            "image": image,
            "input_ids_layoutxlm": torch.from_numpy(np.full(max_seq_length, self.pad_token_id, dtype=int)),
            "bbox": torch.from_numpy(np.zeros((max_seq_length, 4), dtype=int)),
            "words": [],
            "attention_mask_layoutxlm": torch.from_numpy(np.zeros(max_seq_length, dtype=int)),
            "are_box_first_tokens": torch.from_numpy(np.zeros(max_seq_length, dtype=np.bool_)),
        }

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Encode up to ``max_window_count`` windows and their concatenation.

        Windows past the last word repeat the previous window's tensors; in the
        concatenated ``'documents'`` entry their attention mask and first-token
        flags are zeroed. ``n_empty_windows`` counts those padding windows.
        """
        n_words = len(words)
        output_dicts = {'windows': [], 'documents': []}
        windows = output_dicts['windows']
        n_empty_windows = 0

        for i in range(self.max_window_count):
            if n_words == 0:
                n_empty_windows += 1
                windows.append(self._empty_window(feature_maps['image'], max_seq_length))
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)
            if start_word_idx >= stop_word_idx:
                n_empty_windows += 1
                windows.append(windows[-1])
                continue

            list_word_objects = self._build_word_objects(
                bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx]
            )
            (
                bbox, input_ids, attention_mask, are_box_first_tokens,
                _box_to_token_indices, _box2token_span_map, lwords,
                _len_valid_tokens, _len_non_overlap_tokens, _len_list_tokens,
            ) = self.parser_words(list_word_objects, max_seq_length, feature_maps["width"], feature_maps["height"])

            windows.append({
                "image": feature_maps['image'],
                "input_ids_layoutxlm": torch.from_numpy(input_ids),
                "bbox": torch.from_numpy(bbox),
                "words": lwords,
                "attention_mask_layoutxlm": torch.from_numpy(attention_mask),
                "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
            })

        attention_mask = torch.cat([w['attention_mask_layoutxlm'] for w in windows])
        are_box_first_tokens = torch.cat([w['are_box_first_tokens'] for w in windows])
        if n_empty_windows > 0:
            # Empty windows only occur at the tail, so mask everything after the last real one.
            first_empty_slot = max_seq_length * (self.max_window_count - n_empty_windows)
            attention_mask[first_empty_slot:] = 0
            are_box_first_tokens[first_empty_slot:] = False
        bbox = torch.cat([w['bbox'] for w in windows])
        doc_words = list(itertools.chain.from_iterable(w['words'] for w in windows))

        output_dicts['documents'] = {
            "attention_mask_layoutxlm": attention_mask,
            "bbox": bbox,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
            "words": doc_words,
        }
        return output_dicts