# sbt-idp/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/preprocess.py

import torch
import itertools
import numpy as np
from sdsvkvu.sources.utils import sliding_windows
class KVUProcessor:
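    """Preprocessor that turns OCR words and bounding boxes into LayoutXLM model inputs.

    Words are tokenized with the LayoutXLM tokenizer, grouped into (optionally
    overlapping) windows via `sliding_windows`, and packed into fixed-length
    tensors: token ids, attention mask, normalized boxes and box-first-token flags.
    """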
def __init__(
self,
tokenizer_layoutxlm,
feature_extractor,
backbone_type,
class_names,
slice_interval,
window_size,
max_seq_length,
mode,
**kwargs,
):
self.mode = mode
self.class_names = class_names
self.backbone_type = backbone_type
self.window_size = window_size
self.slice_interval = slice_interval
self.max_seq_length = max_seq_length
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
tokenizer_layoutxlm._pad_token
)
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
tokenizer_layoutxlm._cls_token
)
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
tokenizer_layoutxlm._sep_token
)
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(
self.tokenizer_layoutxlm._unk_token
)
        self.class_idx_dic = {
            class_name: idx for idx, class_name in enumerate(self.class_names)
        }
def __call__(self, lbboxes: list, lwords: list, image, width, height) -> dict:
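        """Preprocess one page according to `self.mode`.

        mode 0: use only the first `max_seq_length` tokens of the page.
        mode 1: return one preprocessed entry per sliding window under "windows".
        mode 2: return sliding windows plus a full-document pass (up to 2048
            positions) under "documents".
        """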
image = torch.from_numpy(
self.feature_extractor(image)["pixel_values"][0].copy()
)
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), (
            f"Numbers of bbox windows and word windows do not match after "
            f"sliding window: {len(bbox_windows)} != {len(word_windows)}"
        )
if self.mode == 0: # First 512 tokens
output = self.preprocess_window(
bounding_boxes=lbboxes,
words=lwords,
image_features={"image": image, "width": width, "height": height},
max_seq_length=self.max_seq_length,
)
elif self.mode == 1: # Get full tokens
output = {}
windows = []
for i in range(len(bbox_windows)):
windows.append(
self.preprocess_window(
bounding_boxes=bbox_windows[i],
words=word_windows[i],
image_features={"image": image, "width": width, "height": height},
max_seq_length=self.max_seq_length,
)
)
output["windows"] = windows
elif self.mode == 2: # Sliding window
output = {}
windows = []
output["doduments"] = self.preprocess_window(
bounding_boxes=lbboxes,
words=lwords,
image_features={"image": image, "width": width, "height": height},
max_seq_length=2048,
)
for i in range(len(bbox_windows)):
windows.append(
                    self.preprocess_window(
bounding_boxes=bbox_windows[i],
words=word_windows[i],
image_features={"image": image, "width": width, "height": height},
max_seq_length=self.max_seq_length,
)
)
output["windows"] = windows
else:
            raise ValueError(f"Unsupported mode: {self.mode}")
return output
def preprocess_window(self, bounding_boxes, words, image_features, max_seq_length):
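        """Tokenize one window of words/boxes and pack it into fixed-size model inputs."""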
list_word_objects = []
for bb, text in zip(bounding_boxes, words):
boundingBox = [
[bb[0], bb[1]],
[bb[2], bb[1]],
[bb[2], bb[3]],
[bb[0], bb[3]],
]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(
self.tokenizer_layoutxlm.tokenize(text)
)
list_word_objects.append(
{"layoutxlm_tokens": tokens, "boundingBox": boundingBox, "text": text}
)
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens,
) = self.parser_words(
list_word_objects,
            max_seq_length,
image_features["width"],
image_features["height"],
)
assert len_list_tokens == len_valid_tokens + 2
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
input_ids = input_ids[:ntokens]
attention_mask = attention_mask[:ntokens]
bbox = bbox[:ntokens]
are_box_first_tokens = are_box_first_tokens[:ntokens]
input_ids = torch.from_numpy(input_ids)
attention_mask = torch.from_numpy(attention_mask)
bbox = torch.from_numpy(bbox)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
len_valid_tokens = torch.tensor(len_valid_tokens)
len_overlap_tokens = torch.tensor(len_overlap_tokens)
return_dict = {
"words": lwords,
"len_overlap_tokens": len_overlap_tokens,
"len_valid_tokens": len_valid_tokens,
"image": image_features["image"],
"input_ids_layoutxlm": input_ids,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
"bbox": bbox,
}
return return_dict
def parser_words(self, words, max_seq_length, width, height):
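        """Flatten tokenized words into padded arrays of token ids, boxes and masks.

        Adds [CLS]/[SEP], clips each box to the image bounds, flattens boxes to
        8 coordinates, marks the first token of every box, and tracks valid and
        non-overlapping token counts used when merging sliding windows.
        """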
list_bbs = []
list_words = []
list_tokens = []
cls_bbs = [0.0] * 8
box2token_span_map = []
box_to_token_indices = []
lwords = [""] * max_seq_length
cum_token_idx = 0
len_valid_tokens = 0
len_non_overlap_tokens = 0
input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
for word_idx, word in enumerate(words):
this_box_token_indices = []
tokens = word["layoutxlm_tokens"]
bb = word["boundingBox"]
text = word["text"]
len_valid_tokens += len(tokens)
if word_idx < self.slice_interval:
len_non_overlap_tokens += len(tokens)
if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)
if len(list_tokens) + len(tokens) > max_seq_length - 2:
break
box2token_span_map.append(
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
) # including st_idx
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
texts = [text for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
            list_words.extend(texts)
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [width, height] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id_layoutxlm]
+ list_tokens[: max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
list_words = (
[self.tokenizer_layoutxlm._cls_token]
+ list_words[: max_seq_length - 2]
+ [self.tokenizer_layoutxlm._sep_token]
)
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
lwords[:len_list_tokens] = list_words
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
if self.backbone_type in ("layoutlm", "layoutxlm"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
        else:
            assert False, f"Unsupported backbone type: {self.backbone_type}"
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < max_seq_length
]
are_box_first_tokens[st_indices] = True
return (
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens,
)
class DocKVUProcessor(KVUProcessor):
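    """KVU preprocessor for long documents split into multiple windows.

    The page is cut into up to `max_window_count` windows of `window_size`
    words; each window is preprocessed separately and the results are also
    concatenated into document-level tensors.
    """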
def __init__(
self,
tokenizer_layoutxlm,
feature_extractor,
backbone_type,
class_names,
max_window_count,
slice_interval,
window_size,
max_seq_length,
mode,
**kwargs,
):
super().__init__(
tokenizer_layoutxlm=tokenizer_layoutxlm,
feature_extractor=feature_extractor,
backbone_type=backbone_type,
class_names=class_names,
slice_interval=slice_interval,
window_size=window_size,
max_seq_length=max_seq_length,
mode=mode,
)
self.max_window_count = max_window_count
self.pad_token_id = self.pad_token_id_layoutxlm
self.cls_token_id = self.cls_token_id_layoutxlm
self.sep_token_id = self.sep_token_id_layoutxlm
self.unk_token_id = self.unk_token_id_layoutxlm
self.tokenizer = self.tokenizer_layoutxlm
def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict:
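        """Extract image features and preprocess the whole document via `preprocess_document`."""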
image_features = torch.from_numpy(
self.feature_extractor(images)["pixel_values"][0].copy()
)
output = self.preprocess_document(
bounding_boxes=lbboxes,
words=lwords,
image_features={"image": image_features, "width": width, "height": height},
max_seq_length=self.max_seq_length,
)
return output
def preprocess_document(self, bounding_boxes, words, image_features, max_seq_length):
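        """Build per-window model inputs plus concatenated document-level inputs.

        Windows beyond the available words are padded and counted in
        `n_empty_windows` so that their positions can be masked out below.
        """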
n_words = len(words)
output_dicts = {"windows": [], "documents": []}
n_empty_windows = 0
for i in range(self.max_window_count):
input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
if n_words == 0:
n_empty_windows += 1
output_dicts["windows"].append(
{
"image": image_features["image"],
"input_ids_layoutxlm": torch.from_numpy(input_ids),
"bbox": torch.from_numpy(bbox),
"words": [],
"attention_mask_layoutxlm": torch.from_numpy(attention_mask),
"are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
}
)
continue
start_word_idx = i * self.window_size
stop_word_idx = min(n_words, (i + 1) * self.window_size)
if start_word_idx >= stop_word_idx:
n_empty_windows += 1
output_dicts["windows"].append(output_dicts["windows"][-1])
continue
list_word_objects = []
for bb, text in zip(
bounding_boxes[start_word_idx:stop_word_idx],
words[start_word_idx:stop_word_idx],
):
boundingBox = [
[bb[0], bb[1]],
[bb[2], bb[1]],
[bb[2], bb[3]],
[bb[0], bb[3]],
]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(
self.tokenizer_layoutxlm.tokenize(text)
)
list_word_objects.append(
{
"layoutxlm_tokens": tokens,
"boundingBox": boundingBox,
"text": text,
}
)
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_layoutxlm_tokens,
) = self.parser_words(
list_word_objects,
                max_seq_length,
image_features["width"],
image_features["height"],
)
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
return_dict = {
"bbox": bbox,
"words": lwords,
"image": image_features["image"],
"input_ids_layoutxlm": input_ids,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
}
output_dicts["windows"].append(return_dict)
attention_mask = torch.cat(
[o["attention_mask_layoutxlm"] for o in output_dicts["windows"]]
)
are_box_first_tokens = torch.cat(
[o["are_box_first_tokens"] for o in output_dicts["windows"]]
)
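        # Zero out attention and box-first-token flags for the trailing empty windows.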
if n_empty_windows > 0:
attention_mask[
self.max_seq_length * (self.max_window_count - n_empty_windows) :
] = torch.from_numpy(
np.zeros(self.max_seq_length * n_empty_windows, dtype=int)
)
are_box_first_tokens[
self.max_seq_length * (self.max_window_count - n_empty_windows) :
] = torch.from_numpy(
np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)
)
bbox = torch.cat([o["bbox"] for o in output_dicts["windows"]])
words = []
for o in output_dicts["windows"]:
words.extend(o["words"])
return_dict = {
"bbox": bbox,
"words": words,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
"n_empty_windows": n_empty_windows,
}
output_dicts["documents"] = return_dict
return output_dicts
class SBTProcessor(DocKVUProcessor):
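    """Thin wrapper around `DocKVUProcessor`; preprocessing behaviour is unchanged."""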
def __init__(
self,
tokenizer_layoutxlm,
feature_extractor,
backbone_type,
class_names,
max_window_count,
slice_interval,
window_size,
max_seq_length,
mode,
**kwargs,
):
super().__init__(
tokenizer_layoutxlm,
feature_extractor,
backbone_type,
class_names,
max_window_count,
slice_interval,
window_size,
max_seq_length,
mode,
**kwargs,
)
def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict:
return super().__call__(lbboxes, lwords, images, width, height)
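
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal, hedged example of building a KVUProcessor and running it on dummy
# OCR output. The checkpoint name, class names and window/sequence parameters
# below are assumptions for illustration, not values taken from this
# repository's configuration.
if __name__ == "__main__":
    from PIL import Image
    from transformers import LayoutLMv2FeatureExtractor, LayoutXLMTokenizer

    tokenizer = LayoutXLMTokenizer.from_pretrained("microsoft/layoutxlm-base")
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

    processor = KVUProcessor(
        tokenizer_layoutxlm=tokenizer,
        feature_extractor=feature_extractor,
        backbone_type="layoutxlm",
        class_names=["others", "title", "key", "value"],  # assumed label set
        slice_interval=128,  # assumed window overlap (in words)
        window_size=256,     # assumed words per window
        max_seq_length=512,
        mode=0,              # first-512-tokens mode
    )

    # Dummy OCR output: one word with its [x1, y1, x2, y2] box.
    page = Image.new("RGB", (640, 480), "white")
    lwords = ["example"]
    lbboxes = [[10, 10, 120, 40]]

    batch = processor(lbboxes=lbboxes, lwords=lwords, image=page, width=640, height=480)
    print(batch["input_ids_layoutxlm"].shape, batch["bbox"].shape)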