# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/preprocess.py

import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows
class KVUProcess:
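    """Prepare LayoutXLM model inputs (token ids, bounding boxes, attention
    masks) for key-value understanding from an image and its OCR result,
    optionally split into overlapping sliding windows of words.
    """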
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.max_seq_length = max_seq_length
self.backbone_type = backbone_type
self.class_names = class_names
self.slice_interval = slice_interval
self.window_size = window_size
self.run_ocr = run_ocr
self.mode = mode
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
        self.class_idx_dic = {class_name: idx for idx, class_name in enumerate(self.class_names)}
self.ocr_engine = None
if self.run_ocr == 1:
self.ocr_engine = load_ocr_engine()
    def __call__(self, img_path: str, ocr_path: str) -> dict:
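        """Read (or run) OCR for `img_path`, then build model inputs.

        mode 0 returns a single sample over all words, mode 1 returns sliding
        window samples under `output['windows']`, and mode 2 additionally
        builds a document-level sample with max_seq_length=2048.
        """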
if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
ocr_path = "tmp.txt"
process_img(img_path, ocr_path, self.ocr_engine, export_img=False)
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), f"Mismatched number of bbox and word windows after sliding: {len(bbox_windows)} != {len(word_windows)}"
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
if self.mode == 0:
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
elif self.mode == 1:
output = {}
windows = []
            for _bboxes, _words in zip(bbox_windows, word_windows):
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
elif self.mode == 2:
output = {}
windows = []
            output['documents'] = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=2048)
            for _bboxes, _words in zip(bbox_windows, word_windows):
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
else:
            raise ValueError(
                f"Unsupported mode: {self.mode}"
            )
return output
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
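        """Tokenize every word, build fixed-length token/box/mask arrays via
        `parser_words`, truncate them to the effective sequence length, and
        convert the results to tensors.
        """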
list_word_objects = []
for bb, text in zip(bounding_boxes, words):
boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
list_word_objects.append({
"layoutxlm_tokens": tokens,
"boundingBox": boundingBox,
"text": text
})
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens
        ) = self.parser_words(list_word_objects, max_seq_length, feature_maps["width"], feature_maps["height"])  # use the requested max_seq_length (e.g. 2048 for document-level samples)
assert len_list_tokens == len_valid_tokens + 2
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
input_ids = input_ids[:ntokens]
attention_mask = attention_mask[:ntokens]
bbox = bbox[:ntokens]
are_box_first_tokens = are_box_first_tokens[:ntokens]
input_ids = torch.from_numpy(input_ids)
attention_mask = torch.from_numpy(attention_mask)
bbox = torch.from_numpy(bbox)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
len_valid_tokens = torch.tensor(len_valid_tokens)
len_overlap_tokens = torch.tensor(len_overlap_tokens)
return_dict = {
"img_path": feature_maps['img_path'],
"words": lwords,
"len_overlap_tokens": len_overlap_tokens,
'len_valid_tokens': len_valid_tokens,
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
"bbox": bbox,
}
return return_dict
def parser_words(self, words, max_seq_length, width, height):
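        """Flatten word-level tokens into padded arrays of length max_seq_length.

        Adds [CLS]/[SEP] ids, clips each bounding box to the image size,
        repeats it for every sub-token, normalizes coordinates (scaled to
        0~1000 for layoutlm/layoutxlm backbones), and marks the first token
        of each box.
        """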
list_bbs = []
list_words = []
list_tokens = []
cls_bbs = [0.0] * 8
box2token_span_map = []
box_to_token_indices = []
lwords = [''] * max_seq_length
cum_token_idx = 0
len_valid_tokens = 0
len_non_overlap_tokens = 0
input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
for word_idx, word in enumerate(words):
this_box_token_indices = []
tokens = word["layoutxlm_tokens"]
bb = word["boundingBox"]
text = word["text"]
len_valid_tokens += len(tokens)
if word_idx < self.slice_interval:
len_non_overlap_tokens += len(tokens)
if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)
if len(list_tokens) + len(tokens) > max_seq_length - 2:
break
box2token_span_map.append(
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
) # including st_idx
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
texts = [text for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
            list_words.extend(texts)
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [width, height] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id_layoutxlm]
+ list_tokens[: max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
# list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] ###
# if len(list_words) < 510:
# list_words.extend(['</p>' for _ in range(510 - len(list_words))])
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
lwords[:len_list_tokens] = list_words
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
if self.backbone_type in ("layoutlm", "layoutxlm"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
        else:
            raise ValueError(f"Unsupported backbone type: {self.backbone_type}")
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < max_seq_length
]
are_box_first_tokens[st_indices] = True
return (
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens
)
def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length):
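        """Build token-level labels for entity extraction: `itc_labels` holds
        the class index at each entity's first token, `stc_labels` points each
        subsequent token back to the previous token of the same entity.
        """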
itc_labels = np.zeros(max_seq_length, dtype=int)
stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length
classes_dic = parse_class
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
for word_list in classes_dic[class_name]:
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices):
break
box2token_list = box_to_token_indices[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= max_seq_length:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
return itc_labels, stc_labels
def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length):
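        """Build token-level linking labels from word-level relations:
        `el_labels_from_key` links a value's first token to its key, while
        `el_labels` links a key's or value's first token to its header.
        """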
el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length
el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length
relations = parse_relation
for relation in relations:
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
box2token_span_map
):
continue
if (
box2token_span_map[relation[0]][0] >= max_seq_length
or box2token_span_map[relation[1]][0] >= max_seq_length
):
continue
word_from = box2token_span_map[relation[0]][0]
word_to = box2token_span_map[relation[1]][0]
# el_labels[word_to] = word_from
            if el_labels[word_to] != max_seq_length and el_labels_from_key[word_to] != max_seq_length:  # already linked (max_seq_length is the "no link" sentinel)
continue
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header-key) or (header-value)
return el_labels, el_labels_from_key
class DocumentKVUProcess(KVUProcess):
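    """KVUProcess variant that always splits a page into `max_window_count`
    windows of `window_size` words and also returns concatenated
    document-level masks, boxes, and words.
    """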
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
self.max_window_count = max_window_count
self.pad_token_id = self.pad_token_id_layoutxlm
self.cls_token_id = self.cls_token_id_layoutxlm
self.sep_token_id = self.sep_token_id_layoutxlm
self.unk_token_id = self.unk_token_id_layoutxlm
self.tokenizer = self.tokenizer_layoutxlm
    def __call__(self, img_path: str, ocr_path: str) -> dict:
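        """Run OCR only when requested and no OCR file exists, then build the
        windowed inputs for the whole page.
        """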
if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
ocr_path = "tmp.txt"
process_img(img_path, ocr_path, self.ocr_engine, export_img=False)
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
self.max_seq_length)
return output
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
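        """Build `max_window_count` windows of `window_size` words each.

        Windows past the end of the page are padded (or duplicated from the
        previous window) and their attention masks are zeroed afterwards.
        """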
n_words = len(words)
output_dicts = {'windows': [], 'documents': []}
n_empty_windows = 0
for i in range(self.max_window_count):
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(self.max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
if n_words == 0:
n_empty_windows += 1
output_dicts['windows'].append({
"image": feature_maps['image'],
"input_ids_layoutxlm": torch.from_numpy(input_ids),
"bbox": torch.from_numpy(bbox),
"words": [],
"attention_mask_layoutxlm": torch.from_numpy(attention_mask),
"are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
})
continue
start_word_idx = i * self.window_size
stop_word_idx = min(n_words, (i+1)*self.window_size)
if start_word_idx >= stop_word_idx:
n_empty_windows += 1
output_dicts['windows'].append(output_dicts['windows'][-1])
continue
list_word_objects = []
for bb, text in zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx]):
boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
list_word_objects.append({
"layoutxlm_tokens": tokens,
"boundingBox": boundingBox,
"text": text
})
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_layoutxlm_tokens
) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"])
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
return_dict = {
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids,
"bbox": bbox,
"words": lwords,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
}
output_dicts["windows"].append(return_dict)
attention_mask = torch.cat([o['attention_mask_layoutxlm'] for o in output_dicts["windows"]])
are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
if n_empty_windows > 0:
attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
words = []
for o in output_dicts['windows']:
words.extend(o['words'])
return_dict = {
"attention_mask_layoutxlm": attention_mask,
"bbox": bbox,
"are_box_first_tokens": are_box_first_tokens,
"n_empty_windows": n_empty_windows,
"words": words
}
output_dicts['documents'] = return_dict
return output_dicts
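

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). The checkpoint
# name, class-name ordering, window settings, and input file paths below are
# assumptions for illustration only; adjust them to the actual configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import LayoutXLMTokenizer, LayoutLMv2FeatureExtractor

    tokenizer = LayoutXLMTokenizer.from_pretrained("microsoft/layoutxlm-base")
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

    processor = KVUProcess(
        tokenizer_layoutxlm=tokenizer,
        feature_extractor=feature_extractor,
        backbone_type="layoutxlm",
        class_names=["others", "title", "key", "value", "header"],  # assumed ordering
        slice_interval=2,
        window_size=256,
        run_ocr=0,  # reuse an existing OCR .txt instead of running the OCR engine
        max_seq_length=512,
        mode=0,
    )

    # Hypothetical input files: an image and its OCR result in the expected txt format.
    batch = processor("sample.jpg", "sample_ocr.txt")
    print(sorted(batch.keys()))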