sbt-idp/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py
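"""Preprocessing for the AnyKey_Value (KVU) pipeline.

Turns OCR results (word texts and bounding boxes) into LayoutXLM-ready tensors
(input ids, attention masks, normalized boxes, first-token flags), either for a
whole page or for sliding windows over long documents, and builds training
labels from annotated JSON files.
"""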

import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows


class KVUProcess:
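    """Converts OCR output for a single image into LayoutXLM inputs for key-value understanding."""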

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor
        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names
        self.slice_interval = slice_interval
        self.window_size = window_size
        self.run_ocr = run_ocr
        self.mode = mode

        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._unk_token)

        self.class_idx_dic = dict(
            [(class_name, idx) for idx, class_name in enumerate(self.class_names)]
        )

        self.ocr_engine = None
        if self.run_ocr == 1:
            self.ocr_engine = load_ocr_engine()

    def __call__(self, img_path: str, ocr_path: str) -> dict:
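        """Run OCR if required, then encode the page according to self.mode."""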
        if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
            process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
            ocr_path = "tmp.txt"
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), \
            f"Shapes of lbboxes and lwords differ after sliding windows: {len(bbox_windows)} vs {len(word_windows)}"

        width, height = imagesize.get(img_path)
        images = [Image.open(img_path).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        # mode 0: encode the whole page once
        # mode 1: encode sliding windows only
        # mode 2: encode the whole document (max_seq_length=2048) plus sliding windows
        if self.mode == 0:
            output = self.preprocess(lbboxes, lwords,
                                     {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                                     max_seq_length=self.max_seq_length)
        elif self.mode == 1:
            output = {}
            windows = []
            for i in range(len(bbox_windows)):
                _words = word_windows[i]
                _bboxes = bbox_windows[i]
                windows.append(
                    self.preprocess(
                        _bboxes, _words,
                        {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                        max_seq_length=self.max_seq_length)
                )
            output['windows'] = windows
        elif self.mode == 2:
            output = {}
            windows = []
            output['documents'] = self.preprocess(lbboxes, lwords,
                                                  {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                                                  max_seq_length=2048)
            for i in range(len(bbox_windows)):
                _words = word_windows[i]
                _bboxes = bbox_windows[i]
                windows.append(
                    self.preprocess(
                        _bboxes, _words,
                        {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                        max_seq_length=self.max_seq_length)
                )
            output['windows'] = windows
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")
        return output

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
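        """Encode one sequence of word boxes/texts as fixed-length LayoutXLM inputs.

        Returns a dict with token ids, attention mask, bounding boxes normalized to
        0-1000, first-token flags, and token counts used for window overlap handling.
        """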
        input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
        attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

        list_layoutxlm_tokens = []
        list_bbs = []
        list_words = []
        lwords = [''] * max_seq_length

        box_to_token_indices = []
        cum_token_idx = 0

        cls_bbs = [0.0] * 8
        len_overlap_tokens = 0
        len_non_overlap_tokens = 0
        len_valid_tokens = 0

        for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)):
            bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]],
                  [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
            layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word))

            this_box_token_indices = []

            len_valid_tokens += len(layoutxlm_tokens)
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(layoutxlm_tokens)

            if len(layoutxlm_tokens) == 0:
                layoutxlm_tokens.append(self.unk_token_id_layoutxlm)

            if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
                break

            list_layoutxlm_tokens += layoutxlm_tokens

            # min, max clipping
            for coord_idx in range(4):
                bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
                bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))

            bb = list(itertools.chain(*bb))
            bbs = [bb for _ in range(len(layoutxlm_tokens))]
            texts = [word for _ in range(len(layoutxlm_tokens))]

            for _ in layoutxlm_tokens:
                cum_token_idx += 1
                this_box_token_indices.append(cum_token_idx)

            list_bbs.extend(bbs)
            list_words.extend(texts)
            box_to_token_indices.append(this_box_token_indices)

        sep_bbs = [feature_maps['width'], feature_maps['height']] * 4

        # For [CLS] and [SEP]
        list_layoutxlm_tokens = (
            [self.cls_token_id_layoutxlm]
            + list_layoutxlm_tokens[: max_seq_length - 2]
            + [self.sep_token_id_layoutxlm]
        )
        if len(list_bbs) == 0:
            # When len(json_obj["words"]) == 0 (no OCR result)
            list_bbs = [cls_bbs] + [sep_bbs]
        else:  # len(list_bbs) > 0
            list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
        list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]

        len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
        input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
        attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1
        bbox[:len_list_layoutxlm_tokens, :] = list_bbs
        lwords[:len_list_layoutxlm_tokens] = list_words

        # Normalize bbox -> 0 ~ 1
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']

        if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)
        else:
            assert False, f"Unsupported backbone type: {self.backbone_type}"

        st_indices = [
            indices[0]
            for indices in box_to_token_indices
            if indices[0] < max_seq_length
        ]
        are_box_first_tokens[st_indices] = True

        assert len_list_layoutxlm_tokens == len_valid_tokens + 2
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens

        # Keep the full padded length for standard 512-token windows; otherwise keep only valid tokens plus CLS/SEP
        ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

        input_ids_layoutxlm = input_ids_layoutxlm[:ntokens]
        attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens]
        bbox = bbox[:ntokens]
        are_box_first_tokens = are_box_first_tokens[:ntokens]

        input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm)
        attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm)
        bbox = torch.from_numpy(bbox)
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
        len_valid_tokens = torch.tensor(len_valid_tokens)
        len_overlap_tokens = torch.tensor(len_overlap_tokens)

        return_dict = {
            "img_path": feature_maps['img_path'],
            "words": list_words,
            "len_overlap_tokens": len_overlap_tokens,
            "len_valid_tokens": len_valid_tokens,
            "image": feature_maps['image'],
            "input_ids_layoutxlm": input_ids_layoutxlm,
            "attention_mask_layoutxlm": attention_mask_layoutxlm,
            "are_box_first_tokens": are_box_first_tokens,
            "bbox": bbox,
        }
        return return_dict

    def load_ground_truth(self, json_file):
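        """Build training tensors (inputs plus ITC/STC/EL labels) from an annotated JSON file."""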
        json_obj = read_json(json_file)

        width = json_obj["meta"]["imageSize"]["width"]
        height = json_obj["meta"]["imageSize"]["height"]

        input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
        bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
        attention_mask = np.zeros(self.max_seq_length, dtype=int)

        itc_labels = np.zeros(self.max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)

        # stc_labels stores the index of the previous token.
        # A stored index of max_seq_length (512) indicates that
        # this token is the initial token of a word box.
        stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
        el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
        el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length

        list_tokens = []
        list_bbs = []
        list_words = []
        box2token_span_map = []
        lwords = [''] * self.max_seq_length

        box_to_token_indices = []
        cum_token_idx = 0

        cls_bbs = [0.0] * 8

        for word_idx, word in enumerate(json_obj["words"]):
            this_box_token_indices = []

            tokens = word["layoutxlm_tokens"]
            bb = word["boundingBox"]
            text = word["text"]

            if len(tokens) == 0:
                tokens.append(self.unk_token_id_layoutxlm)

            if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
                break

            box2token_span_map.append(
                [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
            )  # including st_idx
            list_tokens += tokens

            # min, max clipping
            for coord_idx in range(4):
                bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
                bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))

            bb = list(itertools.chain(*bb))
            bbs = [bb for _ in range(len(tokens))]
            texts = [text for _ in range(len(tokens))]

            for _ in tokens:
                cum_token_idx += 1
                this_box_token_indices.append(cum_token_idx)

            list_bbs.extend(bbs)
            list_words.extend(texts)
            box_to_token_indices.append(this_box_token_indices)

        sep_bbs = [width, height] * 4

        # For [CLS] and [SEP]
        list_tokens = (
            [self.cls_token_id_layoutxlm]
            + list_tokens[: self.max_seq_length - 2]
            + [self.sep_token_id_layoutxlm]
        )
        if len(list_bbs) == 0:
            # When len(json_obj["words"]) == 0 (no OCR result)
            list_bbs = [cls_bbs] + [sep_bbs]
        else:  # len(list_bbs) > 0
            list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
        list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]

        len_list_tokens = len(list_tokens)
        input_ids[:len_list_tokens] = list_tokens
        attention_mask[:len_list_tokens] = 1

        bbox[:len_list_tokens, :] = list_bbs
        lwords[:len_list_tokens] = list_words

        # Normalize bbox -> 0 ~ 1
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height

        if self.backbone_type in ("layoutlm", "layoutxlm"):
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)
        else:
            assert False, f"Unsupported backbone type: {self.backbone_type}"

        st_indices = [
            indices[0]
            for indices in box_to_token_indices
            if indices[0] < self.max_seq_length
        ]
        are_box_first_tokens[st_indices] = True

        # Entity (ITC/STC) labels
        classes_dic = json_obj["parse"]["class"]
        for class_name in self.class_names:
            if class_name == "others":
                continue
            if class_name not in classes_dic:
                continue

            for word_list in classes_dic[class_name]:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    if word_idx >= len(box_to_token_indices):
                        break
                    box2token_list = box_to_token_indices[word_idx]
                    for converted_word_idx in box2token_list:
                        if converted_word_idx >= self.max_seq_length:
                            break  # out of idx

                        if is_first:
                            itc_labels[converted_word_idx] = self.class_idx_dic[class_name]
                            is_first, last_word_idx = False, converted_word_idx
                        else:
                            stc_labels[converted_word_idx] = last_word_idx
                            last_word_idx = converted_word_idx

        # Relation (EL) labels
        relations = json_obj["parse"]["relations"]
        for relation in relations:
            if relation[0] >= len(box2token_span_map) or relation[1] >= len(box2token_span_map):
                continue
            if (
                box2token_span_map[relation[0]][0] >= self.max_seq_length
                or box2token_span_map[relation[1]][0] >= self.max_seq_length
            ):
                continue

            word_from = box2token_span_map[relation[0]][0]
            word_to = box2token_span_map[relation[1]][0]
            # el_labels[word_to] = word_from

            if el_labels[word_to] != self.max_seq_length and el_labels_from_key[word_to] != self.max_seq_length:
                continue

            # if self.second_relations == 1:
            #     if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
            #         el_labels[word_to] = word_from  # pair of (header, key) or (header, value)
            # else:
            # 1st relation => ['key', 'value']
            # 2nd relation => ['header', 'key' or 'value']
            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key, value)
            if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header, key) or (header, value)

        input_ids = torch.from_numpy(input_ids)
        bbox = torch.from_numpy(bbox)
        attention_mask = torch.from_numpy(attention_mask)

        itc_labels = torch.from_numpy(itc_labels)
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
        stc_labels = torch.from_numpy(stc_labels)
        el_labels = torch.from_numpy(el_labels)
        el_labels_from_key = torch.from_numpy(el_labels_from_key)

        return_dict = {
            # "image": feature_maps,
            "input_ids": input_ids,
            "bbox": bbox,
            "words": lwords,
            "attention_mask": attention_mask,
            "itc_labels": itc_labels,
            "are_box_first_tokens": are_box_first_tokens,
            "stc_labels": stc_labels,
            "el_labels": el_labels,
            "el_labels_from_key": el_labels_from_key,
        }
        return return_dict


class DocumentKVUProcess(KVUProcess):
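    """KVUProcess variant that splits a document into a fixed number of windows and merges them."""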

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> dict:
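        """Run OCR if required, then encode the page into fixed windows plus a merged document view."""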
        if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
            process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
            ocr_path = "tmp.txt"
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        width, height = imagesize.get(img_path)
        images = [Image.open(img_path).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        output = self.preprocess(lbboxes, lwords,
                                 {'image': image_features, 'width': width, 'height': height, 'img_path': img_path})
        return output

    def preprocess(self, bounding_boxes, words, feature_maps):
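        """Encode the page as up to max_window_count fixed-size windows plus a merged view.

        Returns {'windows': [...], 'documents': {...}}, where 'documents' concatenates the
        attention masks, bounding boxes and first-token flags of all windows.
        """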
        n_words = len(words)
        output_dicts = {'windows': [], 'documents': []}
        n_empty_windows = 0

        for i in range(self.max_window_count):
            input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
            bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
            attention_mask = np.zeros(self.max_seq_length, dtype=int)
            are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)

            if n_words == 0:
                n_empty_windows += 1
                output_dicts['windows'].append({
                    "image": feature_maps['image'],
                    "input_ids": torch.from_numpy(input_ids),
                    "bbox": torch.from_numpy(bbox),
                    "words": [],
                    "attention_mask": torch.from_numpy(attention_mask),
                    "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
                })
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)

            if start_word_idx >= stop_word_idx:
                n_empty_windows += 1
                output_dicts['windows'].append(output_dicts['windows'][-1])
                continue

            list_tokens = []
            list_bbs = []
            list_words = []
            lwords = [''] * self.max_seq_length

            box_to_token_indices = []
            cum_token_idx = 0

            cls_bbs = [0.0] * 8

            for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])):
                bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]],
                      [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
                tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))

                this_box_token_indices = []

                if len(tokens) == 0:
                    tokens.append(self.unk_token_id)

                if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
                    break

                list_tokens += tokens

                # min, max clipping
                for coord_idx in range(4):
                    bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
                    bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))

                bb = list(itertools.chain(*bb))
                bbs = [bb for _ in range(len(tokens))]
                texts = [word for _ in range(len(tokens))]

                for _ in tokens:
                    cum_token_idx += 1
                    this_box_token_indices.append(cum_token_idx)

                list_bbs.extend(bbs)
                list_words.extend(texts)
                box_to_token_indices.append(this_box_token_indices)

            sep_bbs = [feature_maps['width'], feature_maps['height']] * 4

            # For [CLS] and [SEP]
            list_tokens = (
                [self.cls_token_id]
                + list_tokens[: self.max_seq_length - 2]
                + [self.sep_token_id]
            )
            if len(list_bbs) == 0:
                # When len(json_obj["words"]) == 0 (no OCR result)
                list_bbs = [cls_bbs] + [sep_bbs]
            else:  # len(list_bbs) > 0
                list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]

            # Pad the word list with '</p>' up to 510 entries so every window exposes a fixed-length word list
            if len(list_words) < 510:
                list_words.extend(['</p>' for _ in range(510 - len(list_words))])
            list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token]

            len_list_tokens = len(list_tokens)
            input_ids[:len_list_tokens] = list_tokens
            attention_mask[:len_list_tokens] = 1

            bbox[:len_list_tokens, :] = list_bbs
            lwords[:len_list_tokens] = list_words

            # Normalize bbox -> 0 ~ 1
            bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
            bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']

            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)

            st_indices = [
                indices[0]
                for indices in box_to_token_indices
                if indices[0] < self.max_seq_length
            ]
            are_box_first_tokens[st_indices] = True

            input_ids = torch.from_numpy(input_ids)
            bbox = torch.from_numpy(bbox)
            attention_mask = torch.from_numpy(attention_mask)
            are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

            return_dict = {
                "image": feature_maps['image'],
                "input_ids": input_ids,
                "bbox": bbox,
                "words": list_words,
                "attention_mask": attention_mask,
                "are_box_first_tokens": are_box_first_tokens,
            }
            output_dicts["windows"].append(return_dict)

        attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]])
        are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
        if n_empty_windows > 0:
            attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
            are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
        bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])

        words = []
        for o in output_dicts['windows']:
            words.extend(o['words'])

        return_dict = {
            "attention_mask": attention_mask,
            "bbox": bbox,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
            "words": words,
        }
        output_dicts['documents'] = return_dict
        return output_dicts