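"""Pre-processing utilities for key-value understanding (KVU).

The processors in this module turn OCR output (word strings and pixel-space
bounding boxes) plus a page image into padded LayoutXLM model inputs, either
for a single window, a list of sliding windows, or a whole document.
"""
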
import torch
import itertools
import numpy as np

from sdsvkvu.sources.utils import sliding_windows


class KVUProcessor:
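    """Build LayoutXLM model inputs from OCR words and bounding boxes.

    Depending on ``mode``, ``__call__`` returns a single window of at most
    ``max_seq_length`` tokens (mode 0), one entry per sliding window (mode 1),
    or both the sliding windows and a full-document encoding (mode 2).
    """
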
    def __init__(
        self,
        tokenizer_layoutxlm,
        feature_extractor,
        backbone_type,
        class_names,
        slice_interval,
        window_size,
        max_seq_length,
        mode,
        **kwargs,
    ):
        self.mode = mode
        self.class_names = class_names
        self.backbone_type = backbone_type

        self.window_size = window_size
        self.slice_interval = slice_interval
        self.max_seq_length = max_seq_length

        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor

        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
            tokenizer_layoutxlm._pad_token
        )
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
            tokenizer_layoutxlm._cls_token
        )
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
            tokenizer_layoutxlm._sep_token
        )
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
            tokenizer_layoutxlm._unk_token
        )

        self.class_idx_dic = {
            class_name: idx for idx, class_name in enumerate(self.class_names)
        }

    def __call__(self, lbboxes: list, lwords: list, image, width, height) -> dict:
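        """Featurize one page.

        ``lbboxes`` are word-level ``[x0, y0, x1, y1]`` boxes in pixels,
        ``lwords`` the matching word strings, and ``image`` the page image
        accepted by ``self.feature_extractor``.
        """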
        image = torch.from_numpy(
            self.feature_extractor(image)["pixel_values"][0].copy()
        )

        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), (
            "lbboxes and lwords differ in length after sliding window: "
            f"{len(bbox_windows)} vs {len(word_windows)}"
        )

        if self.mode == 0:  # First max_seq_length tokens only
            output = self.preprocess_window(
                bounding_boxes=lbboxes,
                words=lwords,
                image_features={"image": image, "width": width, "height": height},
                max_seq_length=self.max_seq_length,
            )
        elif self.mode == 1:  # All tokens, one output per sliding window
            output = {}
            windows = []
            for i in range(len(bbox_windows)):
                windows.append(
                    self.preprocess_window(
                        bounding_boxes=bbox_windows[i],
                        words=word_windows[i],
                        image_features={"image": image, "width": width, "height": height},
                        max_seq_length=self.max_seq_length,
                    )
                )
            output["windows"] = windows
        elif self.mode == 2:  # Sliding windows plus a full-document encoding
            output = {}
            windows = []
            output["documents"] = self.preprocess_window(
                bounding_boxes=lbboxes,
                words=lwords,
                image_features={"image": image, "width": width, "height": height},
                max_seq_length=2048,
            )
            for i in range(len(bbox_windows)):
                windows.append(
                    self.preprocess_window(
                        bounding_boxes=bbox_windows[i],
                        words=word_windows[i],
                        image_features={"image": image, "width": width, "height": height},
                        max_seq_length=self.max_seq_length,
                    )
                )
            output["windows"] = windows
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")
        return output

    def preprocess_window(self, bounding_boxes, words, image_features, max_seq_length):
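        """Encode one window of words into fixed-length LayoutXLM inputs.

        Returns a dict with the padded ``input_ids_layoutxlm``, ``bbox``,
        ``attention_mask_layoutxlm``, and ``are_box_first_tokens`` tensors,
        the image features, and bookkeeping token counts.
        """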
        list_word_objects = []
        for bb, text in zip(bounding_boxes, words):
            boundingBox = [
                [bb[0], bb[1]],
                [bb[2], bb[1]],
                [bb[2], bb[3]],
                [bb[0], bb[3]],
            ]
            tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(
                self.tokenizer_layoutxlm.tokenize(text)
            )
            list_word_objects.append(
                {"layoutxlm_tokens": tokens, "boundingBox": boundingBox, "text": text}
            )

        (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens,
        ) = self.parser_words(
            list_word_objects,
            max_seq_length,  # honor the caller-provided window length
            image_features["width"],
            image_features["height"],
        )

        assert len_list_tokens == len_valid_tokens + 2
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens

        ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

        input_ids = input_ids[:ntokens]
        attention_mask = attention_mask[:ntokens]
        bbox = bbox[:ntokens]
        are_box_first_tokens = are_box_first_tokens[:ntokens]

        input_ids = torch.from_numpy(input_ids)
        attention_mask = torch.from_numpy(attention_mask)
        bbox = torch.from_numpy(bbox)
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

        len_valid_tokens = torch.tensor(len_valid_tokens)
        len_overlap_tokens = torch.tensor(len_overlap_tokens)
        return_dict = {
            "words": lwords,
            "len_overlap_tokens": len_overlap_tokens,
            "len_valid_tokens": len_valid_tokens,
            "image": image_features["image"],
            "input_ids_layoutxlm": input_ids,
            "attention_mask_layoutxlm": attention_mask,
            "are_box_first_tokens": are_box_first_tokens,
            "bbox": bbox,
        }
        return return_dict

    def parser_words(self, words, max_seq_length, width, height):
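        """Flatten word objects into padded token, box, and mask arrays.

        Returns numpy arrays of length ``max_seq_length`` (token ids, 4-point
        boxes, attention mask, first-token flags) together with span maps and
        token counts used by the callers.
        """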
        list_bbs = []
        list_words = []
        list_tokens = []
        cls_bbs = [0.0] * 8
        box2token_span_map = []
        box_to_token_indices = []
        lwords = [""] * max_seq_length

        cum_token_idx = 0
        len_valid_tokens = 0
        len_non_overlap_tokens = 0

        input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        attention_mask = np.zeros(max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

        for word_idx, word in enumerate(words):
            this_box_token_indices = []

            tokens = word["layoutxlm_tokens"]
            bb = word["boundingBox"]
            text = word["text"]

            len_valid_tokens += len(tokens)
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(tokens)

            if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)

            if len(list_tokens) + len(tokens) > max_seq_length - 2:
                break

            box2token_span_map.append(
                [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
            )  # including st_idx
            list_tokens += tokens

            # Clip box coordinates into the page (min/max clipping)
            for coord_idx in range(4):
                bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
                bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))

            bb = list(itertools.chain(*bb))
            bbs = [bb for _ in range(len(tokens))]
            texts = [text for _ in range(len(tokens))]

            for _ in tokens:
                cum_token_idx += 1
                this_box_token_indices.append(cum_token_idx)

            list_bbs.extend(bbs)
            list_words.extend(texts)
            box_to_token_indices.append(this_box_token_indices)

        sep_bbs = [width, height] * 4

        # Add [CLS] and [SEP]
        list_tokens = (
            [self.cls_token_id_layoutxlm]
            + list_tokens[: max_seq_length - 2]
            + [self.sep_token_id_layoutxlm]
        )
        if len(list_bbs) == 0:
            # No OCR result for this window
            list_bbs = [cls_bbs] + [sep_bbs]
        else:  # len(list_bbs) > 0
            list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
        list_words = (
            [self.tokenizer_layoutxlm._cls_token]
            + list_words[: max_seq_length - 2]
            + [self.tokenizer_layoutxlm._sep_token]
        )

        len_list_tokens = len(list_tokens)
        input_ids[:len_list_tokens] = list_tokens
        attention_mask[:len_list_tokens] = 1

        bbox[:len_list_tokens, :] = list_bbs
        lwords[:len_list_tokens] = list_words

        # Normalize bbox -> 0 ~ 1
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height

        if self.backbone_type in ("layoutlm", "layoutxlm"):
            # LayoutLM/LayoutXLM expect [x0, y0, x1, y1] scaled to 0-1000
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)
        else:
            raise ValueError(f"Unsupported backbone type: {self.backbone_type}")

        st_indices = [
            indices[0]
            for indices in box_to_token_indices
            if indices[0] < max_seq_length
        ]
        are_box_first_tokens[st_indices] = True

        return (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens,
        )


class DocKVUProcessor(KVUProcessor):
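    """KVU processor that encodes a whole document as a fixed number of windows.

    The document is split into ``max_window_count`` windows of ``window_size``
    words each; windows beyond the available words are padded and masked out.
    """
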
    def __init__(
        self,
        tokenizer_layoutxlm,
        feature_extractor,
        backbone_type,
        class_names,
        max_window_count,
        slice_interval,
        window_size,
        max_seq_length,
        mode,
        **kwargs,
    ):
        super().__init__(
            tokenizer_layoutxlm=tokenizer_layoutxlm,
            feature_extractor=feature_extractor,
            backbone_type=backbone_type,
            class_names=class_names,
            slice_interval=slice_interval,
            window_size=window_size,
            max_seq_length=max_seq_length,
            mode=mode,
        )

        self.max_window_count = max_window_count
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict:
        image_features = torch.from_numpy(
            self.feature_extractor(images)["pixel_values"][0].copy()
        )
        output = self.preprocess_document(
            bounding_boxes=lbboxes,
            words=lwords,
            image_features={"image": image_features, "width": width, "height": height},
            max_seq_length=self.max_seq_length,
        )
        return output

    def preprocess_document(self, bounding_boxes, words, image_features, max_seq_length):
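        """Encode every window of the document and concatenate the results.

        Returns ``{"windows": [...], "documents": {...}}`` where the
        ``documents`` entry stacks the per-window masks, boxes, and words and
        records ``n_empty_windows`` for padded windows.
        """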
        n_words = len(words)
        output_dicts = {"windows": [], "documents": []}
        n_empty_windows = 0

        for i in range(self.max_window_count):
            input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id
            bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
            attention_mask = np.zeros(max_seq_length, dtype=int)
            are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

            if n_words == 0:
                n_empty_windows += 1
                output_dicts["windows"].append(
                    {
                        "image": image_features["image"],
                        "input_ids_layoutxlm": torch.from_numpy(input_ids),
                        "bbox": torch.from_numpy(bbox),
                        "words": [],
                        "attention_mask_layoutxlm": torch.from_numpy(attention_mask),
                        "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
                    }
                )
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)

            if start_word_idx >= stop_word_idx:
                n_empty_windows += 1
                output_dicts["windows"].append(output_dicts["windows"][-1])
                continue

            list_word_objects = []
            for bb, text in zip(
                bounding_boxes[start_word_idx:stop_word_idx],
                words[start_word_idx:stop_word_idx],
            ):
                boundingBox = [
                    [bb[0], bb[1]],
                    [bb[2], bb[1]],
                    [bb[2], bb[3]],
                    [bb[0], bb[3]],
                ]
                tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(
                    self.tokenizer_layoutxlm.tokenize(text)
                )
                list_word_objects.append(
                    {
                        "layoutxlm_tokens": tokens,
                        "boundingBox": boundingBox,
                        "text": text,
                    }
                )

            (
                bbox,
                input_ids,
                attention_mask,
                are_box_first_tokens,
                box_to_token_indices,
                box2token_span_map,
                lwords,
                len_valid_tokens,
                len_non_overlap_tokens,
                len_list_layoutxlm_tokens,
            ) = self.parser_words(
                list_word_objects,
                self.max_seq_length,
                image_features["width"],
                image_features["height"],
            )

            input_ids = torch.from_numpy(input_ids)
            bbox = torch.from_numpy(bbox)
            attention_mask = torch.from_numpy(attention_mask)
            are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

            return_dict = {
                "bbox": bbox,
                "words": lwords,
                "image": image_features["image"],
                "input_ids_layoutxlm": input_ids,
                "attention_mask_layoutxlm": attention_mask,
                "are_box_first_tokens": are_box_first_tokens,
            }
            output_dicts["windows"].append(return_dict)

        attention_mask = torch.cat(
            [o["attention_mask_layoutxlm"] for o in output_dicts["windows"]]
        )
        are_box_first_tokens = torch.cat(
            [o["are_box_first_tokens"] for o in output_dicts["windows"]]
        )
        if n_empty_windows > 0:
            attention_mask[
                self.max_seq_length * (self.max_window_count - n_empty_windows) :
            ] = torch.from_numpy(
                np.zeros(self.max_seq_length * n_empty_windows, dtype=int)
            )
            are_box_first_tokens[
                self.max_seq_length * (self.max_window_count - n_empty_windows) :
            ] = torch.from_numpy(
                np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)
            )
        bbox = torch.cat([o["bbox"] for o in output_dicts["windows"]])
        words = []
        for o in output_dicts["windows"]:
            words.extend(o["words"])

        return_dict = {
            "bbox": bbox,
            "words": words,
            "attention_mask_layoutxlm": attention_mask,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
        }
        output_dicts["documents"] = return_dict

        return output_dicts


class SBTProcessor(DocKVUProcessor):
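    """Thin subclass of ``DocKVUProcessor``; it forwards all arguments unchanged."""
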
    def __init__(
        self,
        tokenizer_layoutxlm,
        feature_extractor,
        backbone_type,
        class_names,
        max_window_count,
        slice_interval,
        window_size,
        max_seq_length,
        mode,
        **kwargs,
    ):
        super().__init__(
            tokenizer_layoutxlm,
            feature_extractor,
            backbone_type,
            class_names,
            max_window_count,
            slice_interval,
            window_size,
            max_seq_length,
            mode,
            **kwargs,
        )

    def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict:
        return super().__call__(lbboxes, lwords, images, width, height)
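

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# HuggingFace `transformers` LayoutXLM tokenizer / LayoutLMv2 feature extractor
# and a toy label set; the classes, checkpoint, and label names actually used
# by sdsvkvu may differ, and the tokenizer must expose the private
# `_pad_token` / `_cls_token` / `_sep_token` / `_unk_token` attributes that
# KVUProcessor reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import LayoutLMv2FeatureExtractor, LayoutXLMTokenizer

    tokenizer = LayoutXLMTokenizer.from_pretrained("microsoft/layoutxlm-base")
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

    processor = KVUProcessor(
        tokenizer_layoutxlm=tokenizer,
        feature_extractor=feature_extractor,
        backbone_type="layoutxlm",
        class_names=["others", "key", "value", "header"],  # hypothetical labels
        slice_interval=128,
        window_size=256,
        max_seq_length=512,
        mode=0,  # encode only the first max_seq_length tokens
    )

    # Toy OCR output: two words with [x0, y0, x1, y1] boxes in pixels.
    page = Image.new("RGB", (640, 480), color="white")
    lwords = ["Invoice", "Total"]
    lbboxes = [[10, 10, 120, 40], [10, 60, 100, 90]]

    batch = processor(lbboxes, lwords, page, width=640, height=480)
    print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})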