601 lines
24 KiB
Python
Executable File
601 lines
24 KiB
Python
Executable File
import os
|
|
from typing import Any
|
|
import numpy as np
|
|
import pandas as pd
|
|
import imagesize
|
|
import itertools
|
|
from PIL import Image
|
|
import argparse
|
|
|
|
import torch
|
|
|
|
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
|
|
from utils.run_ocr import load_ocr_engine, process_img
|
|
from lightning_modules.utils import sliding_windows
|
|
|
|
|
|
class KVUProcess:
|
|
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
|
|
self.tokenizer_layoutxlm = tokenizer_layoutxlm
|
|
self.feature_extractor = feature_extractor
|
|
|
|
self.max_seq_length = max_seq_length
|
|
self.backbone_type = backbone_type
|
|
self.class_names = class_names
|
|
|
|
self.slice_interval = slice_interval
|
|
self.window_size = window_size
|
|
self.run_ocr = run_ocr
|
|
self.mode = mode
|
|
|
|
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
|
|
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
|
|
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
|
|
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
|
|
|
|
|
|
self.class_idx_dic = dict(
|
|
[(class_name, idx) for idx, class_name in enumerate(self.class_names)]
|
|
)
|
|
self.ocr_engine = None
|
|
if self.run_ocr == 1:
|
|
self.ocr_engine = load_ocr_engine()
|
|
|
|
def __call__(self, img_path: str, ocr_path: str) -> list:
|
|
if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
|
|
process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
|
|
ocr_path = "tmp.txt"
|
|
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
|
|
lwords = post_process_basic_ocr(lwords)
|
|
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
|
|
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
|
|
assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"
|
|
|
|
width, height = imagesize.get(img_path)
|
|
images = [Image.open(img_path).convert("RGB")]
|
|
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
|
|
|
|
|
|
if self.mode == 0:
|
|
output = self.preprocess(lbboxes, lwords,
|
|
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
|
max_seq_length=self.max_seq_length)
|
|
elif self.mode == 1:
|
|
output = {}
|
|
windows = []
|
|
for i in range(len(bbox_windows)):
|
|
_words = word_windows[i]
|
|
_bboxes = bbox_windows[i]
|
|
windows.append(
|
|
self.preprocess(
|
|
_bboxes, _words,
|
|
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
|
max_seq_length=self.max_seq_length)
|
|
)
|
|
|
|
output['windows'] = windows
|
|
elif self.mode == 2:
|
|
output = {}
|
|
windows = []
|
|
output['doduments'] = self.preprocess(lbboxes, lwords,
|
|
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
|
max_seq_length=2048)
|
|
for i in range(len(bbox_windows)):
|
|
_words = word_windows[i]
|
|
_bboxes = bbox_windows[i]
|
|
windows.append(
|
|
self.preprocess(
|
|
_bboxes, _words,
|
|
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
|
max_seq_length=self.max_seq_length)
|
|
)
|
|
|
|
output['windows'] = windows
|
|
else:
|
|
raise ValueError(
|
|
f"Not supported mode: {self.mode }"
|
|
)
|
|
return output
|
|
|
|
|
|
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
|
|
input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
|
|
|
|
attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
|
|
|
|
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
|
|
|
|
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
|
|
|
|
list_layoutxlm_tokens = []
|
|
|
|
list_bbs = []
|
|
list_words = []
|
|
lwords = [''] * max_seq_length
|
|
|
|
box_to_token_indices = []
|
|
cum_token_idx = 0
|
|
|
|
cls_bbs = [0.0] * 8
|
|
len_overlap_tokens = 0
|
|
len_non_overlap_tokens = 0
|
|
len_valid_tokens = 0
|
|
|
|
for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)):
|
|
bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
|
|
layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word))
|
|
|
|
this_box_token_indices = []
|
|
|
|
len_valid_tokens += len(layoutxlm_tokens)
|
|
if word_idx < self.slice_interval:
|
|
len_non_overlap_tokens += len(layoutxlm_tokens)
|
|
|
|
if len(layoutxlm_tokens) == 0:
|
|
layoutxlm_tokens.append(self.unk_token_id)
|
|
|
|
|
|
if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
|
|
break
|
|
|
|
|
|
list_layoutxlm_tokens += layoutxlm_tokens
|
|
|
|
# min, max clipping
|
|
for coord_idx in range(4):
|
|
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
|
|
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
|
|
|
|
bb = list(itertools.chain(*bb))
|
|
bbs = [bb for _ in range(len(layoutxlm_tokens))]
|
|
texts = [word for _ in range(len(layoutxlm_tokens))]
|
|
|
|
for _ in layoutxlm_tokens:
|
|
cum_token_idx += 1
|
|
this_box_token_indices.append(cum_token_idx)
|
|
|
|
list_bbs.extend(bbs)
|
|
list_words.extend(texts) ###
|
|
box_to_token_indices.append(this_box_token_indices)
|
|
|
|
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
|
|
|
|
# For [CLS] and [SEP]
|
|
list_layoutxlm_tokens = (
|
|
[self.cls_token_id_layoutxlm]
|
|
+ list_layoutxlm_tokens[: max_seq_length - 2]
|
|
+ [self.sep_token_id_layoutxlm]
|
|
)
|
|
|
|
if len(list_bbs) == 0:
|
|
# When len(json_obj["words"]) == 0 (no OCR result)
|
|
list_bbs = [cls_bbs] + [sep_bbs]
|
|
else: # len(list_bbs) > 0
|
|
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
|
|
# list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP']
|
|
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
|
|
|
|
len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
|
|
input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
|
|
attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1
|
|
|
|
|
|
bbox[:len_list_layoutxlm_tokens, :] = list_bbs
|
|
lwords[:len_list_layoutxlm_tokens] = list_words ###
|
|
|
|
# Normalize bbox -> 0 ~ 1
|
|
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
|
|
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
|
|
|
|
if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
|
|
bbox = bbox[:, [0, 1, 4, 5]]
|
|
bbox = bbox * 1000
|
|
bbox = bbox.astype(int)
|
|
else:
|
|
assert False
|
|
|
|
st_indices = [
|
|
indices[0]
|
|
for indices in box_to_token_indices
|
|
if indices[0] < max_seq_length
|
|
]
|
|
are_box_first_tokens[st_indices] = True
|
|
|
|
assert len_list_layoutxlm_tokens == len_valid_tokens + 2
|
|
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
|
|
|
|
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
|
|
|
|
input_ids_layoutxlm = input_ids_layoutxlm[:ntokens]
|
|
attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens]
|
|
bbox = bbox[:ntokens]
|
|
are_box_first_tokens = are_box_first_tokens[:ntokens]
|
|
|
|
|
|
input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm)
|
|
attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm)
|
|
bbox = torch.from_numpy(bbox)
|
|
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
|
|
|
|
len_valid_tokens = torch.tensor(len_valid_tokens)
|
|
len_overlap_tokens = torch.tensor(len_overlap_tokens)
|
|
return_dict = {
|
|
"img_path": feature_maps['img_path'],
|
|
"words": list_words,
|
|
"len_overlap_tokens": len_overlap_tokens,
|
|
'len_valid_tokens': len_valid_tokens,
|
|
"image": feature_maps['image'],
|
|
"input_ids_layoutxlm": input_ids_layoutxlm,
|
|
"attention_mask_layoutxlm": attention_mask_layoutxlm,
|
|
"are_box_first_tokens": are_box_first_tokens,
|
|
"bbox": bbox,
|
|
}
|
|
return return_dict
|
|
|
|
def load_ground_truth(self, json_file):
|
|
json_obj = read_json(json_file)
|
|
width = json_obj["meta"]["imageSize"]["width"]
|
|
height = json_obj["meta"]["imageSize"]["height"]
|
|
|
|
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
|
|
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
|
|
attention_mask = np.zeros(self.max_seq_length, dtype=int)
|
|
|
|
itc_labels = np.zeros(self.max_seq_length, dtype=int)
|
|
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
|
|
|
|
# stc_labels stores the index of the previous token.
|
|
# A stored index of max_seq_length (512) indicates that
|
|
# this token is the initial token of a word box.
|
|
stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
|
|
el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
|
|
el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
|
|
|
|
|
|
list_tokens = []
|
|
list_bbs = []
|
|
list_words = []
|
|
box2token_span_map = []
|
|
lwords = [''] * self.max_seq_length
|
|
|
|
box_to_token_indices = []
|
|
cum_token_idx = 0
|
|
|
|
cls_bbs = [0.0] * 8
|
|
|
|
for word_idx, word in enumerate(json_obj["words"]):
|
|
this_box_token_indices = []
|
|
|
|
tokens = word["layoutxlm_tokens"]
|
|
bb = word["boundingBox"]
|
|
text = word["text"]
|
|
|
|
if len(tokens) == 0:
|
|
tokens.append(self.unk_token_id)
|
|
|
|
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
|
|
break
|
|
|
|
box2token_span_map.append(
|
|
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
|
|
) # including st_idx
|
|
list_tokens += tokens
|
|
|
|
# min, max clipping
|
|
for coord_idx in range(4):
|
|
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
|
|
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
|
|
|
|
bb = list(itertools.chain(*bb))
|
|
bbs = [bb for _ in range(len(tokens))]
|
|
texts = [text for _ in range(len(tokens))]
|
|
|
|
for _ in tokens:
|
|
cum_token_idx += 1
|
|
this_box_token_indices.append(cum_token_idx)
|
|
|
|
list_bbs.extend(bbs)
|
|
list_words.extend(texts) ####
|
|
box_to_token_indices.append(this_box_token_indices)
|
|
|
|
sep_bbs = [width, height] * 4
|
|
|
|
# For [CLS] and [SEP]
|
|
list_tokens = (
|
|
[self.cls_token_id_layoutxlm]
|
|
+ list_tokens[: self.max_seq_length - 2]
|
|
+ [self.sep_token_id_layoutxlm]
|
|
)
|
|
if len(list_bbs) == 0:
|
|
# When len(json_obj["words"]) == 0 (no OCR result)
|
|
list_bbs = [cls_bbs] + [sep_bbs]
|
|
else: # len(list_bbs) > 0
|
|
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
|
|
# list_words = ['CLS'] + list_words[: self.max_seq_length - 2] + ['SEP'] ###
|
|
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
|
|
|
|
|
|
len_list_tokens = len(list_tokens)
|
|
input_ids[:len_list_tokens] = list_tokens
|
|
attention_mask[:len_list_tokens] = 1
|
|
|
|
bbox[:len_list_tokens, :] = list_bbs
|
|
lwords[:len_list_tokens] = list_words
|
|
|
|
# Normalize bbox -> 0 ~ 1
|
|
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
|
|
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
|
|
|
|
if self.backbone_type in ("layoutlm", "layoutxlm"):
|
|
bbox = bbox[:, [0, 1, 4, 5]]
|
|
bbox = bbox * 1000
|
|
bbox = bbox.astype(int)
|
|
else:
|
|
assert False
|
|
|
|
st_indices = [
|
|
indices[0]
|
|
for indices in box_to_token_indices
|
|
if indices[0] < self.max_seq_length
|
|
]
|
|
are_box_first_tokens[st_indices] = True
|
|
|
|
# Label
|
|
classes_dic = json_obj["parse"]["class"]
|
|
for class_name in self.class_names:
|
|
if class_name == "others":
|
|
continue
|
|
if class_name not in classes_dic:
|
|
continue
|
|
|
|
for word_list in classes_dic[class_name]:
|
|
is_first, last_word_idx = True, -1
|
|
for word_idx in word_list:
|
|
if word_idx >= len(box_to_token_indices):
|
|
break
|
|
box2token_list = box_to_token_indices[word_idx]
|
|
for converted_word_idx in box2token_list:
|
|
if converted_word_idx >= self.max_seq_length:
|
|
break # out of idx
|
|
|
|
if is_first:
|
|
itc_labels[converted_word_idx] = self.class_idx_dic[
|
|
class_name
|
|
]
|
|
is_first, last_word_idx = False, converted_word_idx
|
|
else:
|
|
stc_labels[converted_word_idx] = last_word_idx
|
|
last_word_idx = converted_word_idx
|
|
|
|
# Label
|
|
relations = json_obj["parse"]["relations"]
|
|
for relation in relations:
|
|
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
|
|
box2token_span_map
|
|
):
|
|
continue
|
|
if (
|
|
box2token_span_map[relation[0]][0] >= self.max_seq_length
|
|
or box2token_span_map[relation[1]][0] >= self.max_seq_length
|
|
):
|
|
continue
|
|
|
|
word_from = box2token_span_map[relation[0]][0]
|
|
word_to = box2token_span_map[relation[1]][0]
|
|
# el_labels[word_to] = word_from
|
|
|
|
if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512:
|
|
continue
|
|
|
|
# if self.second_relations == 1:
|
|
# if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
|
|
# el_labels[word_to] = word_from # pair of (header, key) or (header-value)
|
|
# else:
|
|
#### 1st relation => ['key, 'value']
|
|
#### 2st relation => ['header', 'key'or'value']
|
|
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
|
|
el_labels_from_key[word_to] = word_from # pair of (key-value)
|
|
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
|
|
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
|
|
|
|
|
|
|
|
input_ids = torch.from_numpy(input_ids)
|
|
bbox = torch.from_numpy(bbox)
|
|
attention_mask = torch.from_numpy(attention_mask)
|
|
|
|
itc_labels = torch.from_numpy(itc_labels)
|
|
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
|
|
stc_labels = torch.from_numpy(stc_labels)
|
|
el_labels = torch.from_numpy(el_labels)
|
|
el_labels_from_key = torch.from_numpy(el_labels_from_key)
|
|
|
|
return_dict = {
|
|
# "image": feature_maps,
|
|
"input_ids": input_ids,
|
|
"bbox": bbox,
|
|
"words": lwords,
|
|
"attention_mask": attention_mask,
|
|
"itc_labels": itc_labels,
|
|
"are_box_first_tokens": are_box_first_tokens,
|
|
"stc_labels": stc_labels,
|
|
"el_labels": el_labels,
|
|
"el_labels_from_key": el_labels_from_key
|
|
}
|
|
|
|
return return_dict
|
|
|
|
|
|
class DocumentKVUProcess(KVUProcess):
|
|
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
|
|
super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
|
|
self.max_window_count = max_window_count
|
|
self.pad_token_id = self.pad_token_id_layoutxlm
|
|
self.cls_token_id = self.cls_token_id_layoutxlm
|
|
self.sep_token_id = self.sep_token_id_layoutxlm
|
|
self.unk_token_id = self.unk_token_id_layoutxlm
|
|
self.tokenizer = self.tokenizer_layoutxlm
|
|
|
|
def __call__(self, img_path: str, ocr_path: str) -> list:
|
|
if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
|
|
process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
|
|
ocr_path = "tmp.txt"
|
|
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
|
|
lwords = post_process_basic_ocr(lwords)
|
|
|
|
width, height = imagesize.get(img_path)
|
|
images = [Image.open(img_path).convert("RGB")]
|
|
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
|
|
output = self.preprocess(lbboxes, lwords,
|
|
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path})
|
|
return output
|
|
|
|
def preprocess(self, bounding_boxes, words, feature_maps):
|
|
n_words = len(words)
|
|
output_dicts = {'windows': [], 'documents': []}
|
|
n_empty_windows = 0
|
|
|
|
for i in range(self.max_window_count):
|
|
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
|
|
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
|
|
attention_mask = np.zeros(self.max_seq_length, dtype=int)
|
|
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
|
|
|
|
if n_words == 0:
|
|
n_empty_windows += 1
|
|
output_dicts['windows'].append({
|
|
"image": feature_maps['image'],
|
|
"input_ids": torch.from_numpy(input_ids),
|
|
"bbox": torch.from_numpy(bbox),
|
|
"words": [],
|
|
"attention_mask": torch.from_numpy(attention_mask),
|
|
"are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
|
|
})
|
|
continue
|
|
|
|
start_word_idx = i * self.window_size
|
|
stop_word_idx = min(n_words, (i+1)*self.window_size)
|
|
|
|
if start_word_idx >= stop_word_idx:
|
|
n_empty_windows += 1
|
|
output_dicts['windows'].append(output_dicts['windows'][-1])
|
|
continue
|
|
|
|
list_tokens = []
|
|
list_bbs = []
|
|
list_words = []
|
|
lwords = [''] * self.max_seq_length
|
|
|
|
box_to_token_indices = []
|
|
cum_token_idx = 0
|
|
|
|
cls_bbs = [0.0] * 8
|
|
|
|
for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])):
|
|
bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
|
|
tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))
|
|
|
|
this_box_token_indices = []
|
|
|
|
if len(tokens) == 0:
|
|
tokens.append(self.unk_token_id)
|
|
|
|
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
|
|
break
|
|
|
|
list_tokens += tokens
|
|
|
|
# min, max clipping
|
|
for coord_idx in range(4):
|
|
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
|
|
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
|
|
|
|
bb = list(itertools.chain(*bb))
|
|
bbs = [bb for _ in range(len(tokens))]
|
|
texts = [word for _ in range(len(tokens))]
|
|
|
|
for _ in tokens:
|
|
cum_token_idx += 1
|
|
this_box_token_indices.append(cum_token_idx)
|
|
|
|
list_bbs.extend(bbs)
|
|
list_words.extend(texts) ###
|
|
box_to_token_indices.append(this_box_token_indices)
|
|
|
|
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
|
|
|
|
# For [CLS] and [SEP]
|
|
list_tokens = (
|
|
[self.cls_token_id]
|
|
+ list_tokens[: self.max_seq_length - 2]
|
|
+ [self.sep_token_id]
|
|
)
|
|
if len(list_bbs) == 0:
|
|
# When len(json_obj["words"]) == 0 (no OCR result)
|
|
list_bbs = [cls_bbs] + [sep_bbs]
|
|
else: # len(list_bbs) > 0
|
|
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
|
|
if len(list_words) < 510:
|
|
list_words.extend(['</p>' for _ in range(510 - len(list_words))])
|
|
list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token]
|
|
|
|
len_list_tokens = len(list_tokens)
|
|
input_ids[:len_list_tokens] = list_tokens
|
|
attention_mask[:len_list_tokens] = 1
|
|
|
|
bbox[:len_list_tokens, :] = list_bbs
|
|
lwords[:len_list_tokens] = list_words ###
|
|
|
|
# Normalize bbox -> 0 ~ 1
|
|
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
|
|
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
|
|
|
|
bbox = bbox[:, [0, 1, 4, 5]]
|
|
bbox = bbox * 1000
|
|
bbox = bbox.astype(int)
|
|
|
|
st_indices = [
|
|
indices[0]
|
|
for indices in box_to_token_indices
|
|
if indices[0] < self.max_seq_length
|
|
]
|
|
are_box_first_tokens[st_indices] = True
|
|
|
|
input_ids = torch.from_numpy(input_ids)
|
|
bbox = torch.from_numpy(bbox)
|
|
attention_mask = torch.from_numpy(attention_mask)
|
|
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
|
|
|
|
return_dict = {
|
|
"image": feature_maps['image'],
|
|
"input_ids": input_ids,
|
|
"bbox": bbox,
|
|
"words": list_words,
|
|
"attention_mask": attention_mask,
|
|
"are_box_first_tokens": are_box_first_tokens,
|
|
}
|
|
output_dicts["windows"].append(return_dict)
|
|
|
|
attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]])
|
|
are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
|
|
if n_empty_windows > 0:
|
|
attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
|
|
are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
|
|
bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
|
|
words = []
|
|
for o in output_dicts['windows']:
|
|
words.extend(o['words'])
|
|
|
|
return_dict = {
|
|
"attention_mask": attention_mask,
|
|
"bbox": bbox,
|
|
"are_box_first_tokens": are_box_first_tokens,
|
|
"n_empty_windows": n_empty_windows,
|
|
"words": words
|
|
}
|
|
output_dicts['documents'] = return_dict
|
|
|
|
return output_dicts
|
|
|
|
|
|
|