sbt-idp/cope2n-ai-fi/common/AnyKey_Value/preprocess.py

456 lines
19 KiB
Python
Raw Normal View History

2023-11-30 11:22:16 +00:00
import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows
class KVUProcess:
    """Convert an image and its OCR result into LayoutXLM-style model inputs.

    Calling an instance reads (or produces) the OCR text file for an image,
    tokenizes every OCR word and packs token ids, bounding boxes and masks
    into fixed-size arrays. ``mode`` selects what is returned:

    * ``0`` - one sample built from all words.
    * ``1`` - ``{'windows': [...]}``, one sample per sliding window.
    * ``2`` - like ``1`` plus a document-level sample (stored under the
      historical key ``'doduments'``) built with ``max_seq_length=2048``.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        """Store the configuration and cache the tokenizer's special-token ids.

        Args:
            tokenizer_layoutxlm: Tokenizer exposing ``tokenize``,
                ``convert_tokens_to_ids`` and the ``_pad_token``,
                ``_cls_token``, ``_sep_token`` and ``_unk_token`` attributes.
            feature_extractor: Callable applied to a list of PIL images; its
                result is indexed as ``['pixel_values'][0]``.
            backbone_type: Backbone name; ``parser_words`` accepts only
                ``"layoutlm"`` and ``"layoutxlm"``.
            class_names: Ordered class labels; the position of a label is its
                class index.
            slice_interval: Passed to ``sliding_windows``; also the number of
                leading words of a sample counted as non-overlapping tokens.
            window_size: Passed to ``sliding_windows``.
            run_ocr: ``1`` loads the OCR engine and re-runs OCR on every call.
            max_seq_length: Token capacity of one sample, including the
                [CLS] and [SEP] positions.
            mode: Output layout, see the class docstring.
        """
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor

        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names

        self.slice_interval = slice_interval
        self.window_size = window_size
        self.run_ocr = run_ocr
        self.mode = mode

        to_id = tokenizer_layoutxlm.convert_tokens_to_ids
        self.pad_token_id_layoutxlm = to_id(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = to_id(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = to_id(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = to_id(tokenizer_layoutxlm._unk_token)

        self.class_idx_dic = {
            class_name: idx for idx, class_name in enumerate(self.class_names)
        }

        # The OCR engine is only loaded when OCR is explicitly requested.
        self.ocr_engine = None
        if self.run_ocr == 1:
            self.ocr_engine = load_ocr_engine()

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Build the model inputs for one image.

        Args:
            img_path: Path of the image file.
            ocr_path: Path of the OCR text file for that image.

        Returns:
            A dict whose layout depends on ``self.mode`` (see the class
            docstring).

        Raises:
            ValueError: If ``self.mode`` is not 0, 1 or 2.
        """
        # NOTE(review): when run_ocr != 1 and ocr_path is missing, process_img
        # is called with self.ocr_engine still None - confirm it copes.
        # NOTE(review): the OCR output always goes to "tmp.txt" in the current
        # working directory, so concurrent calls share that file.
        if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
            ocr_path = "tmp.txt"
            process_img(img_path, ocr_path, self.ocr_engine, export_img=False)

        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"

        width, height = imagesize.get(img_path)
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as img:
            images = [img.convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        feature_maps = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}

        if self.mode == 0:
            output = self.preprocess(lbboxes, lwords, feature_maps,
                                     max_seq_length=self.max_seq_length)
        elif self.mode == 1:
            output = {}
            output['windows'] = self._preprocess_windows(bbox_windows, word_windows, feature_maps)
        elif self.mode == 2:
            output = {}
            # NOTE(review): 'doduments' looks like a misspelling of
            # 'documents'; the key is kept because code outside this class
            # may read it.
            output['doduments'] = self.preprocess(lbboxes, lwords, feature_maps,
                                                  max_seq_length=2048)
            output['windows'] = self._preprocess_windows(bbox_windows, word_windows, feature_maps)
        else:
            raise ValueError(
                f"Not supported mode: {self.mode }"
            )
        return output

    def _preprocess_windows(self, bbox_windows, word_windows, feature_maps):
        """Run ``preprocess`` on every (boxes, words) window pair."""
        return [
            self.preprocess(_bboxes, _words, feature_maps,
                            max_seq_length=self.max_seq_length)
            for _bboxes, _words in zip(bbox_windows, word_windows)
        ]

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Tokenize the words and pack them into one sample of tensors.

        Args:
            bounding_boxes: One box per word, indexed as
                ``bb[0], bb[1], bb[2], bb[3]`` (used as x1, y1, x2, y2).
            words: The word strings, aligned with ``bounding_boxes``.
            feature_maps: Dict with the keys ``'image'``, ``'width'``,
                ``'height'`` and ``'img_path'``.
            max_seq_length: Token capacity used to pack this sample.

        Returns:
            Dict with the keys ``img_path``, ``words``, ``len_overlap_tokens``,
            ``len_valid_tokens``, ``image``, ``input_ids_layoutxlm``,
            ``attention_mask_layoutxlm``, ``are_box_first_tokens`` and
            ``bbox``.
        """
        tokenizer = self.tokenizer_layoutxlm
        list_word_objects = [
            {
                "layoutxlm_tokens": tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)),
                # Four corners: top-left, top-right, bottom-right, bottom-left.
                "boundingBox": [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]],
                "text": text,
            }
            for bb, text in zip(bounding_boxes, words)
        ]

        # Pack with the requested capacity. Using self.max_seq_length here
        # would silently ignore the larger document-level budget that mode 2
        # passes in.
        (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            _box_to_token_indices,
            _box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens,
        ) = self.parser_words(list_word_objects, max_seq_length, feature_maps["width"], feature_maps["height"])

        # Fails when the words did not all fit: len_valid_tokens also counts
        # the word that overflowed the capacity.
        assert len_list_tokens == len_valid_tokens + 2
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens

        # A 512-token sample keeps its padding; any other capacity is trimmed
        # to the used tokens plus [CLS] and [SEP].
        # NOTE(review): the literal 512 is not tied to self.max_seq_length.
        ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

        input_ids = torch.from_numpy(input_ids[:ntokens])
        attention_mask = torch.from_numpy(attention_mask[:ntokens])
        bbox = torch.from_numpy(bbox[:ntokens])
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens[:ntokens])

        return {
            "img_path": feature_maps['img_path'],
            "words": lwords,
            "len_overlap_tokens": torch.tensor(len_overlap_tokens),
            'len_valid_tokens': torch.tensor(len_valid_tokens),
            "image": feature_maps['image'],
            "input_ids_layoutxlm": input_ids,
            "attention_mask_layoutxlm": attention_mask,
            "are_box_first_tokens": are_box_first_tokens,
            "bbox": bbox,
        }

    def parser_words(self, words, max_seq_length, width, height):
        """Pack word objects into fixed-size token, box and mask arrays.

        Args:
            words: List of dicts with the keys ``"layoutxlm_tokens"`` (list of
                token ids), ``"boundingBox"`` (four ``[x, y]`` corners) and
                ``"text"``. Empty token lists and the corner coordinates are
                modified in place.
            max_seq_length: Length of every returned array, including the
                [CLS] and [SEP] positions.
            width: Image width, used to clip and normalize x coordinates.
            height: Image height, used to clip and normalize y coordinates.

        Returns:
            Tuple ``(bbox, input_ids, attention_mask, are_box_first_tokens,
            box_to_token_indices, box2token_span_map, lwords,
            len_valid_tokens, len_non_overlap_tokens, len_list_tokens)``.
            ``bbox`` has four integer columns scaled to 0..1000.

        Raises:
            ValueError: If ``self.backbone_type`` is neither ``"layoutlm"``
                nor ``"layoutxlm"``.
        """
        list_bbs = []
        list_words = []
        list_tokens = []
        box2token_span_map = []
        box_to_token_indices = []
        lwords = [''] * max_seq_length

        len_valid_tokens = 0
        len_non_overlap_tokens = 0

        input_ids = np.full(max_seq_length, self.pad_token_id_layoutxlm, dtype=int)
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        attention_mask = np.zeros(max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

        for word_idx, word in enumerate(words):
            tokens = word["layoutxlm_tokens"]
            bb = word["boundingBox"]
            text = word["text"]

            # A word that produced no tokens is represented by the unknown
            # token. Substitute it before counting so the token counters stay
            # in step with what is actually packed.
            if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)

            len_valid_tokens += len(tokens)
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(tokens)

            # Two positions are reserved for [CLS] and [SEP].
            if len(list_tokens) + len(tokens) > max_seq_length - 2:
                break

            # Token positions are 1-based because position 0 holds [CLS];
            # the span end is exclusive.
            first_idx = len(list_tokens) + 1
            end_idx = first_idx + len(tokens)
            box2token_span_map.append([first_idx, end_idx])
            box_to_token_indices.append(list(range(first_idx, end_idx)))
            list_tokens += tokens

            # Clip every corner to the image rectangle.
            for coord_idx in range(4):
                bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
                bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))

            # Every token of a word shares the word's box and text.
            flat_bb = list(itertools.chain.from_iterable(bb))
            list_bbs.extend([flat_bb] * len(tokens))
            list_words.extend([text] * len(tokens))

        # [CLS] gets an all-zero box, [SEP] the full-image corner.
        cls_bbs = [0.0] * 8
        sep_bbs = [width, height] * 4

        body_len = max_seq_length - 2
        list_tokens = (
            [self.cls_token_id_layoutxlm]
            + list_tokens[:body_len]
            + [self.sep_token_id_layoutxlm]
        )
        # Also correct when there is no OCR result: only [CLS] and [SEP] remain.
        list_bbs = [cls_bbs] + list_bbs[:body_len] + [sep_bbs]
        list_words = (
            [self.tokenizer_layoutxlm._cls_token]
            + list_words[:body_len]
            + [self.tokenizer_layoutxlm._sep_token]
        )

        len_list_tokens = len(list_tokens)
        input_ids[:len_list_tokens] = list_tokens
        attention_mask[:len_list_tokens] = 1
        bbox[:len_list_tokens, :] = list_bbs
        lwords[:len_list_tokens] = list_words

        # Normalize bbox -> 0 ~ 1
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height

        if self.backbone_type in ("layoutlm", "layoutxlm"):
            # Keep the top-left and bottom-right corners, scaled to 0..1000.
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)
        else:
            # An explicit raise is not stripped under ``python -O``.
            raise ValueError(f"Not supported backbone_type: {self.backbone_type}")

        st_indices = [
            indices[0]
            for indices in box_to_token_indices
            if indices[0] < max_seq_length
        ]
        are_box_first_tokens[st_indices] = True

        return (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens,
        )

    def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length):
        """Build per-token class labels and token-chaining labels.

        Args:
            parse_class: Dict mapping a class name to a list of entities, each
                entity being a list of word indices.
            box_to_token_indices: For every word, the token positions it
                occupies.
            max_seq_length: Length of the returned arrays.

        Returns:
            ``(itc_labels, stc_labels)``. ``itc_labels`` holds the class index
            at the first token of every entity and 0 elsewhere.
            ``stc_labels`` holds, for every later token of an entity, the
            position of the preceding token, and ``max_seq_length`` elsewhere.
        """
        itc_labels = np.zeros(max_seq_length, dtype=int)
        stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length

        for class_name in self.class_names:
            if class_name == "others" or class_name not in parse_class:
                continue

            class_idx = self.class_idx_dic[class_name]
            for word_list in parse_class[class_name]:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    # Words past the packed ones end this entity.
                    if word_idx >= len(box_to_token_indices):
                        break
                    for converted_word_idx in box_to_token_indices[word_idx]:
                        if converted_word_idx >= max_seq_length:
                            break  # out of idx

                        if is_first:
                            itc_labels[converted_word_idx] = class_idx
                            is_first = False
                        else:
                            stc_labels[converted_word_idx] = last_word_idx
                        last_word_idx = converted_word_idx

        return itc_labels, stc_labels

    def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length):
        """Build entity-linking labels from word-level relations.

        Args:
            parse_relation: Iterable of ``(from_word_idx, to_word_idx)`` pairs.
            itc_labels: Per-token class indices; the values 2, 3 and 4 are
                treated as key, value and header respectively.
            box2token_span_map: For every word, its ``[start, end)`` token
                span.
            max_seq_length: Length of the returned arrays; also the value that
                marks a position without a link.

        Returns:
            ``(el_labels, el_labels_from_key)``. ``el_labels[to]`` is the
            header token linked to a key or value token;
            ``el_labels_from_key[to]`` is the key token linked to a value
            token. Positions without a link hold ``max_seq_length``.
        """
        el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length
        el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length

        n_spans = len(box2token_span_map)
        for relation in parse_relation:
            # Skip relations that touch words which were not packed.
            if relation[0] >= n_spans or relation[1] >= n_spans:
                continue

            word_from = box2token_span_map[relation[0]][0]
            word_to = box2token_span_map[relation[1]][0]
            if word_from >= max_seq_length or word_to >= max_seq_length:
                continue

            # Skip targets that already have both kinds of link. The sentinel
            # is max_seq_length (the fill value above), not a fixed 512.
            if el_labels[word_to] != max_seq_length and el_labels_from_key[word_to] != max_seq_length:
                continue

            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key-value)
            if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

        return el_labels, el_labels_from_key
class DocumentKVUProcess(KVUProcess):
    """Document-level variant that always emits ``max_window_count`` windows.

    The words are cut into consecutive, non-overlapping windows of
    ``window_size`` words. Each window is packed with ``parser_words``, and
    the per-window masks and boxes are concatenated into one document-level
    entry.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        """Initialize the base processor and add un-suffixed attribute aliases.

        Args:
            max_window_count: Number of windows every document is padded to.

        The remaining arguments are forwarded unchanged to
        ``KVUProcess.__init__``.
        """
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count

        # Short aliases for the LayoutXLM-specific attributes of the base class.
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Build the windowed model inputs for one image.

        Args:
            img_path: Path of the image file.
            ocr_path: Path of the OCR text file for that image.

        Returns:
            The dict produced by ``preprocess`` (keys ``'windows'`` and
            ``'documents'``).
        """
        # NOTE(review): this uses ``and`` where KVUProcess.__call__ uses
        # ``or``: OCR only runs when it is enabled AND no OCR file exists yet.
        # Confirm the difference is intended.
        if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
            ocr_path = "tmp.txt"
            process_img(img_path, ocr_path, self.ocr_engine, export_img=False)

        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        width, height = imagesize.get(img_path)
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as img:
            images = [img.convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        feature_maps = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}
        return self.preprocess(lbboxes, lwords, feature_maps, self.max_seq_length)

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Pack the words into ``max_window_count`` windows plus a document entry.

        Args:
            bounding_boxes: One box per word, indexed as
                ``bb[0], bb[1], bb[2], bb[3]`` (used as x1, y1, x2, y2).
            words: The word strings, aligned with ``bounding_boxes``.
            feature_maps: Dict with the keys ``'image'``, ``'width'`` and
                ``'height'``.
            max_seq_length: Accepted for signature compatibility with the base
                class; this method sizes everything with
                ``self.max_seq_length``.

        Returns:
            ``{'windows': [...], 'documents': {...}}``. ``'windows'`` holds
            one dict per window. ``'documents'`` holds the concatenated
            attention mask, boxes, first-token flags and words, plus
            ``n_empty_windows``.
        """
        n_words = len(words)
        seq_len = self.max_seq_length
        tokenizer = self.tokenizer_layoutxlm
        windows = []
        output_dicts = {'windows': windows, 'documents': []}
        n_empty_windows = 0

        for i in range(self.max_window_count):
            if n_words == 0:
                # No OCR result at all: every window is pure padding.
                # NOTE(review): bbox here has 8 float columns, whereas
                # parser_words returns 4 integer columns - confirm that
                # consumers accept both.
                n_empty_windows += 1
                windows.append({
                    "image": feature_maps['image'],
                    "input_ids_layoutxlm": torch.from_numpy(np.full(seq_len, self.pad_token_id, dtype=int)),
                    "bbox": torch.from_numpy(np.zeros((seq_len, 8), dtype=np.float32)),
                    "words": [],
                    "attention_mask_layoutxlm": torch.from_numpy(np.zeros(seq_len, dtype=int)),
                    "are_box_first_tokens": torch.from_numpy(np.zeros(seq_len, dtype=np.bool_)),
                })
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)

            if start_word_idx >= stop_word_idx:
                # Past the last word: repeat the previous window as filler.
                # Its mask and first-token flags are zeroed in the
                # document-level tensors below.
                n_empty_windows += 1
                windows.append(windows[-1])
                continue

            list_word_objects = [
                {
                    "layoutxlm_tokens": tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)),
                    # Four corners: top-left, top-right, bottom-right, bottom-left.
                    "boundingBox": [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]],
                    "text": text,
                }
                for bb, text in zip(
                    bounding_boxes[start_word_idx:stop_word_idx],
                    words[start_word_idx:stop_word_idx],
                )
            ]

            (
                bbox,
                input_ids,
                attention_mask,
                are_box_first_tokens,
                _box_to_token_indices,
                _box2token_span_map,
                lwords,
                _len_valid_tokens,
                _len_non_overlap_tokens,
                _len_list_layoutxlm_tokens,
            ) = self.parser_words(list_word_objects, seq_len, feature_maps["width"], feature_maps["height"])

            windows.append({
                "image": feature_maps['image'],
                "input_ids_layoutxlm": torch.from_numpy(input_ids),
                "bbox": torch.from_numpy(bbox),
                "words": lwords,
                "attention_mask_layoutxlm": torch.from_numpy(attention_mask),
                "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
            })

        # torch.cat copies, so zeroing the tail below leaves the per-window
        # tensors untouched.
        attention_mask = torch.cat([o['attention_mask_layoutxlm'] for o in windows])
        are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in windows])
        if n_empty_windows > 0:
            first_empty = seq_len * (self.max_window_count - n_empty_windows)
            attention_mask[first_empty:] = 0
            are_box_first_tokens[first_empty:] = False
        bbox = torch.cat([o['bbox'] for o in windows])

        all_words = []
        for o in windows:
            all_words.extend(o['words'])

        output_dicts['documents'] = {
            "attention_mask_layoutxlm": attention_mask,
            "bbox": bbox,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
            "words": all_words,
        }
        return output_dicts