import os from typing import Any import numpy as np import pandas as pd import imagesize import itertools from PIL import Image import argparse import torch from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr from utils.run_ocr import load_ocr_engine, process_img from lightning_modules.utils import sliding_windows class KVUProcess: def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): self.tokenizer_layoutxlm = tokenizer_layoutxlm self.feature_extractor = feature_extractor self.max_seq_length = max_seq_length self.backbone_type = backbone_type self.class_names = class_names self.slice_interval = slice_interval self.window_size = window_size self.run_ocr = run_ocr self.mode = mode self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token) self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token) self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token) self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token) self.class_idx_dic = dict( [(class_name, idx) for idx, class_name in enumerate(self.class_names)] ) self.ocr_engine = None if self.run_ocr == 1: self.ocr_engine = load_ocr_engine() def __call__(self, img_path: str, ocr_path: str) -> list: if (self.run_ocr == 1) or (not os.path.exists(ocr_path)): process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False) ocr_path = "tmp.txt" lbboxes, lwords = read_ocr_result_from_txt(ocr_path) lwords = post_process_basic_ocr(lwords) bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval) word_windows = sliding_windows(lwords, self.window_size, self.slice_interval) assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}" width, height = imagesize.get(img_path) images = [Image.open(img_path).convert("RGB")] image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) if self.mode == 0: output = self.preprocess(lbboxes, lwords, {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, max_seq_length=self.max_seq_length) elif self.mode == 1: output = {} windows = [] for i in range(len(bbox_windows)): _words = word_windows[i] _bboxes = bbox_windows[i] windows.append( self.preprocess( _bboxes, _words, {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, max_seq_length=self.max_seq_length) ) output['windows'] = windows elif self.mode == 2: output = {} windows = [] output['doduments'] = self.preprocess(lbboxes, lwords, {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, max_seq_length=2048) for i in range(len(bbox_windows)): _words = word_windows[i] _bboxes = bbox_windows[i] windows.append( self.preprocess( _bboxes, _words, {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, max_seq_length=self.max_seq_length) ) output['windows'] = windows else: raise ValueError( f"Not supported mode: {self.mode }" ) return output def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length): input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int) bbox = np.zeros((max_seq_length, 8), dtype=np.float32) are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) list_layoutxlm_tokens = [] list_bbs = [] list_words = [] lwords = [''] * max_seq_length box_to_token_indices = [] cum_token_idx = 0 cls_bbs = [0.0] * 8 len_overlap_tokens = 0 len_non_overlap_tokens = 0 len_valid_tokens = 0 for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)): bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]] layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word)) this_box_token_indices = [] len_valid_tokens += len(layoutxlm_tokens) if word_idx < self.slice_interval: len_non_overlap_tokens += len(layoutxlm_tokens) if len(layoutxlm_tokens) == 0: layoutxlm_tokens.append(self.unk_token_id) if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2: break list_layoutxlm_tokens += layoutxlm_tokens # min, max clipping for coord_idx in range(4): bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width'])) bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height'])) bb = list(itertools.chain(*bb)) bbs = [bb for _ in range(len(layoutxlm_tokens))] texts = [word for _ in range(len(layoutxlm_tokens))] for _ in layoutxlm_tokens: cum_token_idx += 1 this_box_token_indices.append(cum_token_idx) list_bbs.extend(bbs) list_words.extend(texts) ### box_to_token_indices.append(this_box_token_indices) sep_bbs = [feature_maps['width'], feature_maps['height']] * 4 # For [CLS] and [SEP] list_layoutxlm_tokens = ( [self.cls_token_id_layoutxlm] + list_layoutxlm_tokens[: max_seq_length - 2] + [self.sep_token_id_layoutxlm] ) if len(list_bbs) == 0: # When len(json_obj["words"]) == 0 (no OCR result) list_bbs = [cls_bbs] + [sep_bbs] else: # len(list_bbs) > 0 list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs] # list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token] len_list_layoutxlm_tokens = len(list_layoutxlm_tokens) input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1 bbox[:len_list_layoutxlm_tokens, :] = list_bbs lwords[:len_list_layoutxlm_tokens] = list_words ### # Normalize bbox -> 0 ~ 1 bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width'] bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height'] if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"): bbox = bbox[:, [0, 1, 4, 5]] bbox = bbox * 1000 bbox = bbox.astype(int) else: assert False st_indices = [ indices[0] for indices in box_to_token_indices if indices[0] < max_seq_length ] are_box_first_tokens[st_indices] = True assert len_list_layoutxlm_tokens == len_valid_tokens + 2 len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2 input_ids_layoutxlm = input_ids_layoutxlm[:ntokens] attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens] bbox = bbox[:ntokens] are_box_first_tokens = are_box_first_tokens[:ntokens] input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm) attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm) bbox = torch.from_numpy(bbox) are_box_first_tokens = torch.from_numpy(are_box_first_tokens) len_valid_tokens = torch.tensor(len_valid_tokens) len_overlap_tokens = torch.tensor(len_overlap_tokens) return_dict = { "img_path": feature_maps['img_path'], "words": list_words, "len_overlap_tokens": len_overlap_tokens, 'len_valid_tokens': len_valid_tokens, "image": feature_maps['image'], "input_ids_layoutxlm": input_ids_layoutxlm, "attention_mask_layoutxlm": attention_mask_layoutxlm, "are_box_first_tokens": are_box_first_tokens, "bbox": bbox, } return return_dict def load_ground_truth(self, json_file): json_obj = read_json(json_file) width = json_obj["meta"]["imageSize"]["width"] height = json_obj["meta"]["imageSize"]["height"] input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) attention_mask = np.zeros(self.max_seq_length, dtype=int) itc_labels = np.zeros(self.max_seq_length, dtype=int) are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) # stc_labels stores the index of the previous token. # A stored index of max_seq_length (512) indicates that # this token is the initial token of a word box. stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length list_tokens = [] list_bbs = [] list_words = [] box2token_span_map = [] lwords = [''] * self.max_seq_length box_to_token_indices = [] cum_token_idx = 0 cls_bbs = [0.0] * 8 for word_idx, word in enumerate(json_obj["words"]): this_box_token_indices = [] tokens = word["layoutxlm_tokens"] bb = word["boundingBox"] text = word["text"] if len(tokens) == 0: tokens.append(self.unk_token_id) if len(list_tokens) + len(tokens) > self.max_seq_length - 2: break box2token_span_map.append( [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1] ) # including st_idx list_tokens += tokens # min, max clipping for coord_idx in range(4): bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width)) bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height)) bb = list(itertools.chain(*bb)) bbs = [bb for _ in range(len(tokens))] texts = [text for _ in range(len(tokens))] for _ in tokens: cum_token_idx += 1 this_box_token_indices.append(cum_token_idx) list_bbs.extend(bbs) list_words.extend(texts) #### box_to_token_indices.append(this_box_token_indices) sep_bbs = [width, height] * 4 # For [CLS] and [SEP] list_tokens = ( [self.cls_token_id_layoutxlm] + list_tokens[: self.max_seq_length - 2] + [self.sep_token_id_layoutxlm] ) if len(list_bbs) == 0: # When len(json_obj["words"]) == 0 (no OCR result) list_bbs = [cls_bbs] + [sep_bbs] else: # len(list_bbs) > 0 list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs] # list_words = ['CLS'] + list_words[: self.max_seq_length - 2] + ['SEP'] ### list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token] len_list_tokens = len(list_tokens) input_ids[:len_list_tokens] = list_tokens attention_mask[:len_list_tokens] = 1 bbox[:len_list_tokens, :] = list_bbs lwords[:len_list_tokens] = list_words # Normalize bbox -> 0 ~ 1 bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height if self.backbone_type in ("layoutlm", "layoutxlm"): bbox = bbox[:, [0, 1, 4, 5]] bbox = bbox * 1000 bbox = bbox.astype(int) else: assert False st_indices = [ indices[0] for indices in box_to_token_indices if indices[0] < self.max_seq_length ] are_box_first_tokens[st_indices] = True # Label classes_dic = json_obj["parse"]["class"] for class_name in self.class_names: if class_name == "others": continue if class_name not in classes_dic: continue for word_list in classes_dic[class_name]: is_first, last_word_idx = True, -1 for word_idx in word_list: if word_idx >= len(box_to_token_indices): break box2token_list = box_to_token_indices[word_idx] for converted_word_idx in box2token_list: if converted_word_idx >= self.max_seq_length: break # out of idx if is_first: itc_labels[converted_word_idx] = self.class_idx_dic[ class_name ] is_first, last_word_idx = False, converted_word_idx else: stc_labels[converted_word_idx] = last_word_idx last_word_idx = converted_word_idx # Label relations = json_obj["parse"]["relations"] for relation in relations: if relation[0] >= len(box2token_span_map) or relation[1] >= len( box2token_span_map ): continue if ( box2token_span_map[relation[0]][0] >= self.max_seq_length or box2token_span_map[relation[1]][0] >= self.max_seq_length ): continue word_from = box2token_span_map[relation[0]][0] word_to = box2token_span_map[relation[1]][0] # el_labels[word_to] = word_from if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512: continue # if self.second_relations == 1: # if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): # el_labels[word_to] = word_from # pair of (header, key) or (header-value) # else: #### 1st relation => ['key, 'value'] #### 2st relation => ['header', 'key'or'value'] if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: el_labels_from_key[word_to] = word_from # pair of (key-value) if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): el_labels[word_to] = word_from # pair of (header, key) or (header-value) input_ids = torch.from_numpy(input_ids) bbox = torch.from_numpy(bbox) attention_mask = torch.from_numpy(attention_mask) itc_labels = torch.from_numpy(itc_labels) are_box_first_tokens = torch.from_numpy(are_box_first_tokens) stc_labels = torch.from_numpy(stc_labels) el_labels = torch.from_numpy(el_labels) el_labels_from_key = torch.from_numpy(el_labels_from_key) return_dict = { # "image": feature_maps, "input_ids": input_ids, "bbox": bbox, "words": lwords, "attention_mask": attention_mask, "itc_labels": itc_labels, "are_box_first_tokens": are_box_first_tokens, "stc_labels": stc_labels, "el_labels": el_labels, "el_labels_from_key": el_labels_from_key } return return_dict class DocumentKVUProcess(KVUProcess): def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode) self.max_window_count = max_window_count self.pad_token_id = self.pad_token_id_layoutxlm self.cls_token_id = self.cls_token_id_layoutxlm self.sep_token_id = self.sep_token_id_layoutxlm self.unk_token_id = self.unk_token_id_layoutxlm self.tokenizer = self.tokenizer_layoutxlm def __call__(self, img_path: str, ocr_path: str) -> list: if (self.run_ocr == 1) and (not os.path.exists(ocr_path)): process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False) ocr_path = "tmp.txt" lbboxes, lwords = read_ocr_result_from_txt(ocr_path) lwords = post_process_basic_ocr(lwords) width, height = imagesize.get(img_path) images = [Image.open(img_path).convert("RGB")] image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) output = self.preprocess(lbboxes, lwords, {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}) return output def preprocess(self, bounding_boxes, words, feature_maps): n_words = len(words) output_dicts = {'windows': [], 'documents': []} n_empty_windows = 0 for i in range(self.max_window_count): input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) attention_mask = np.zeros(self.max_seq_length, dtype=int) are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) if n_words == 0: n_empty_windows += 1 output_dicts['windows'].append({ "image": feature_maps['image'], "input_ids": torch.from_numpy(input_ids), "bbox": torch.from_numpy(bbox), "words": [], "attention_mask": torch.from_numpy(attention_mask), "are_box_first_tokens": torch.from_numpy(are_box_first_tokens), }) continue start_word_idx = i * self.window_size stop_word_idx = min(n_words, (i+1)*self.window_size) if start_word_idx >= stop_word_idx: n_empty_windows += 1 output_dicts['windows'].append(output_dicts['windows'][-1]) continue list_tokens = [] list_bbs = [] list_words = [] lwords = [''] * self.max_seq_length box_to_token_indices = [] cum_token_idx = 0 cls_bbs = [0.0] * 8 for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])): bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]] tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word)) this_box_token_indices = [] if len(tokens) == 0: tokens.append(self.unk_token_id) if len(list_tokens) + len(tokens) > self.max_seq_length - 2: break list_tokens += tokens # min, max clipping for coord_idx in range(4): bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width'])) bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height'])) bb = list(itertools.chain(*bb)) bbs = [bb for _ in range(len(tokens))] texts = [word for _ in range(len(tokens))] for _ in tokens: cum_token_idx += 1 this_box_token_indices.append(cum_token_idx) list_bbs.extend(bbs) list_words.extend(texts) ### box_to_token_indices.append(this_box_token_indices) sep_bbs = [feature_maps['width'], feature_maps['height']] * 4 # For [CLS] and [SEP] list_tokens = ( [self.cls_token_id] + list_tokens[: self.max_seq_length - 2] + [self.sep_token_id] ) if len(list_bbs) == 0: # When len(json_obj["words"]) == 0 (no OCR result) list_bbs = [cls_bbs] + [sep_bbs] else: # len(list_bbs) > 0 list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs] if len(list_words) < 510: list_words.extend(['

' for _ in range(510 - len(list_words))]) list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token] len_list_tokens = len(list_tokens) input_ids[:len_list_tokens] = list_tokens attention_mask[:len_list_tokens] = 1 bbox[:len_list_tokens, :] = list_bbs lwords[:len_list_tokens] = list_words ### # Normalize bbox -> 0 ~ 1 bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width'] bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height'] bbox = bbox[:, [0, 1, 4, 5]] bbox = bbox * 1000 bbox = bbox.astype(int) st_indices = [ indices[0] for indices in box_to_token_indices if indices[0] < self.max_seq_length ] are_box_first_tokens[st_indices] = True input_ids = torch.from_numpy(input_ids) bbox = torch.from_numpy(bbox) attention_mask = torch.from_numpy(attention_mask) are_box_first_tokens = torch.from_numpy(are_box_first_tokens) return_dict = { "image": feature_maps['image'], "input_ids": input_ids, "bbox": bbox, "words": list_words, "attention_mask": attention_mask, "are_box_first_tokens": are_box_first_tokens, } output_dicts["windows"].append(return_dict) attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]]) are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]]) if n_empty_windows > 0: attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int)) are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)) bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]]) words = [] for o in output_dicts['windows']: words.extend(o['words']) return_dict = { "attention_mask": attention_mask, "bbox": bbox, "are_box_first_tokens": are_box_first_tokens, "n_empty_windows": n_empty_windows, "words": words } output_dicts['documents'] = return_dict return output_dicts