# sbt-idp/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py
#
# Source listing: 601 lines, 24 KiB, Python; revision dated 2023-11-30 11:22:16 +00:00.
import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows
class KVUProcess:
    """Turn OCR output (and optional ground-truth JSON) into LayoutXLM model inputs.

    ``__call__`` reads OCR boxes/words for an image and encodes them according
    to ``mode``:

    * ``0`` -- one sequence for the whole page, truncated to ``max_seq_length``;
    * ``1`` -- one sequence per sliding window of words;
    * ``2`` -- both: a whole-page sequence of up to 2048 tokens under
      ``"documents"`` plus the per-window sequences under ``"windows"``.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        """Store the configuration and cache the tokenizer's special-token ids.

        Args:
            tokenizer_layoutxlm: tokenizer exposing ``tokenize``,
                ``convert_tokens_to_ids`` and the ``_pad_token`` / ``_cls_token`` /
                ``_sep_token`` / ``_unk_token`` attributes.
            feature_extractor: callable mapping a list of PIL images to a dict
                with a ``'pixel_values'`` entry.
            backbone_type: one of ``"layoutlm"``, ``"layoutxlm"``, ``"xlm-roberta"``.
            class_names: label names; a name's list position is its class index.
            slice_interval: passed to ``sliding_windows`` (presumably the window
                stride, in words). Tokens of the first ``slice_interval`` words of a
                sequence are counted as non-overlapping.
            window_size: words per sliding window, passed to ``sliding_windows``.
            run_ocr: ``1`` loads the OCR engine eagerly and re-runs OCR on every call.
            max_seq_length: padded sequence length, including [CLS] and [SEP].
            mode: output layout, see the class docstring.
        """
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor
        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names
        self.slice_interval = slice_interval
        self.window_size = window_size
        self.run_ocr = run_ocr
        self.mode = mode
        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._unk_token)
        self.class_idx_dic = {class_name: idx for idx, class_name in enumerate(self.class_names)}
        self.ocr_engine = load_ocr_engine() if self.run_ocr == 1 else None

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Run (or load) OCR for ``img_path`` and encode it according to ``self.mode``.

        OCR is (re-)run into ``tmp.txt`` in the working directory when
        ``run_ocr == 1`` or when ``ocr_path`` does not exist.

        Returns:
            mode 0: the dict produced by :meth:`preprocess`.
            mode 1: ``{"windows": [...]}``, one :meth:`preprocess` dict per window.
            mode 2: ``{"documents": ..., "doduments": ..., "windows": [...]}``;
            ``"doduments"`` is a legacy misspelled alias of ``"documents"``.

        Raises:
            ValueError: if ``self.mode`` is not 0, 1 or 2.
        """
        if self.run_ocr == 1 or not os.path.exists(ocr_path):
            if self.ocr_engine is None:
                # The engine is only loaded eagerly when run_ocr == 1; load it on
                # demand for the missing-OCR-file fallback instead of passing None.
                self.ocr_engine = load_ocr_engine()
            process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
            ocr_path = "tmp.txt"
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)
        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"
        width, height = imagesize.get(img_path)
        # Close the file handle deterministically; convert() returns an independent copy.
        with Image.open(img_path) as img:
            images = [img.convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
        feature_maps = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}

        if self.mode == 0:
            return self.preprocess(lbboxes, lwords, feature_maps, max_seq_length=self.max_seq_length)
        if self.mode == 1:
            return {'windows': self._preprocess_windows(bbox_windows, word_windows, feature_maps)}
        if self.mode == 2:
            document = self.preprocess(lbboxes, lwords, feature_maps, max_seq_length=2048)
            return {
                'documents': document,
                'doduments': document,  # legacy misspelled key, kept for existing callers
                'windows': self._preprocess_windows(bbox_windows, word_windows, feature_maps),
            }
        raise ValueError(
            f"Not supported mode: {self.mode }"
        )

    def _preprocess_windows(self, bbox_windows, word_windows, feature_maps):
        """Encode each sliding window with :meth:`preprocess` at ``self.max_seq_length``."""
        return [
            self.preprocess(bboxes, words, feature_maps, max_seq_length=self.max_seq_length)
            for bboxes, words in zip(bbox_windows, word_windows)
        ]

    @staticmethod
    def _clip_and_flatten(corners, width, height):
        """Clamp each ``[x, y]`` corner into ``[0, width] x [0, height]`` and flatten to one list."""
        return [
            value
            for x, y in corners
            for value in (max(0.0, min(x, width)), max(0.0, min(y, height)))
        ]

    def _normalize_bbox(self, bbox, width, height):
        """Scale an ``(n, 8)`` corner array to ``(n, 4)`` ints on a 0..1000 grid.

        Columns 0, 1, 4, 5 (the first and third corner) are kept; for the boxes
        built in :meth:`preprocess` these are the top-left and bottom-right corners.

        Raises:
            ValueError: if ``self.backbone_type`` is not a supported backbone.
        """
        if self.backbone_type not in ("layoutlm", "layoutxlm", "xlm-roberta"):
            raise ValueError(f"Not supported backbone type: {self.backbone_type}")
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
        return (bbox[:, [0, 1, 4, 5]] * 1000).astype(int)

    def _pack_sequence(self, tokens, bbs, texts, first_token_indices, width, height, max_seq_length):
        """Wrap a token run in [CLS]/[SEP] and scatter it into padded arrays.

        Returns ``(input_ids, attention_mask, bbox, are_box_first_tokens, words, n)``:
        numpy arrays of length ``max_seq_length`` (``bbox`` already normalized by
        :meth:`_normalize_bbox`), the unpadded per-token word list including the
        [CLS]/[SEP] markers, and ``n``, the number of real positions.
        """
        seq_tokens = [self.cls_token_id_layoutxlm] + list(tokens) + [self.sep_token_id_layoutxlm]
        # [CLS] gets an all-zero box, [SEP] a box covering the whole image.
        seq_bbs = [[0.0] * 8] + list(bbs) + [[width, height] * 4]
        seq_words = [self.tokenizer_layoutxlm._cls_token] + list(texts) + [self.tokenizer_layoutxlm._sep_token]
        n = len(seq_tokens)

        input_ids = np.full(max_seq_length, self.pad_token_id_layoutxlm, dtype=int)
        input_ids[:n] = seq_tokens
        attention_mask = np.zeros(max_seq_length, dtype=int)
        attention_mask[:n] = 1
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        bbox[:n, :] = seq_bbs
        bbox = self._normalize_bbox(bbox, width, height)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
        are_box_first_tokens[[idx for idx in first_token_indices if idx < max_seq_length]] = True
        return input_ids, attention_mask, bbox, are_box_first_tokens, seq_words, n

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Encode one run of OCR words as a [CLS] ... [SEP] LayoutXLM sequence.

        Words are added in order until the next word's tokens would no longer fit
        in ``max_seq_length - 2``; the remaining words are dropped. A word that
        tokenizes to nothing is represented by a single [UNK] token.

        Args:
            bounding_boxes: per-word boxes read as ``[x1, y1, x2, y2]`` in image
                coordinates.
            words: per-word strings, aligned with ``bounding_boxes``.
            feature_maps: dict with ``'image'``, ``'width'``, ``'height'`` and
                ``'img_path'``; width/height clip and normalize the boxes.
            max_seq_length: capacity of the sequence including [CLS] and [SEP].

        Returns:
            dict with ``img_path``, ``words`` (per-token source word, unpadded),
            ``len_overlap_tokens`` and ``len_valid_tokens`` (0-d tensors),
            ``image``, and the tensors ``input_ids_layoutxlm``,
            ``attention_mask_layoutxlm``, ``are_box_first_tokens`` and ``bbox``
            (``(n, 4)`` ints on a 0..1000 grid). ``n`` is ``max_seq_length`` when
            that equals 512, otherwise the real length ``len_valid_tokens + 2``.

        Raises:
            ValueError: if ``self.backbone_type`` is not supported.
        """
        width, height = feature_maps['width'], feature_maps['height']
        tokenizer = self.tokenizer_layoutxlm
        tokens, bbs, texts, first_token_indices = [], [], [], []
        len_non_overlap_tokens = 0
        for word_idx, (box, word) in enumerate(zip(bounding_boxes, words)):
            word_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
            if len(word_tokens) == 0:
                word_tokens = [self.unk_token_id_layoutxlm]
            if len(tokens) + len(word_tokens) > max_seq_length - 2:
                break
            # Count only tokens that actually enter the sequence, so that
            # len_valid_tokens always equals the number of non-special tokens.
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(word_tokens)
            corners = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], [box[0], box[3]]]
            flat_box = self._clip_and_flatten(corners, width, height)
            first_token_indices.append(len(tokens) + 1)  # +1: position 0 is [CLS]
            tokens.extend(word_tokens)
            bbs.extend([flat_box] * len(word_tokens))
            texts.extend([word] * len(word_tokens))

        input_ids, attention_mask, bbox, are_box_first_tokens, list_words, _ = self._pack_sequence(
            tokens, bbs, texts, first_token_indices, width, height, max_seq_length
        )
        len_valid_tokens = len(tokens)
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
        # NOTE(review): only a 512 request keeps the padded window; other lengths
        # (e.g. mode 2's 2048-token document pass) are trimmed to the real length.
        ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

        return {
            "img_path": feature_maps['img_path'],
            "words": list_words,
            "len_overlap_tokens": torch.tensor(len_overlap_tokens),
            'len_valid_tokens': torch.tensor(len_valid_tokens),
            "image": feature_maps['image'],
            "input_ids_layoutxlm": torch.from_numpy(input_ids[:ntokens]),
            "attention_mask_layoutxlm": torch.from_numpy(attention_mask[:ntokens]),
            "are_box_first_tokens": torch.from_numpy(are_box_first_tokens[:ntokens]),
            "bbox": torch.from_numpy(bbox[:ntokens]),
        }

    def load_ground_truth(self, json_file):
        """Build padded model inputs and training labels from an annotation JSON file.

        The JSON is read as: ``meta.imageSize.{width,height}``; ``words`` -- a list
        of ``{"layoutxlm_tokens", "boundingBox" (corner ``[x, y]`` pairs), "text"}``;
        ``parse.class`` -- class name -> list of word-index lists (one per entity);
        ``parse.relations`` -- ``[from_word_idx, to_word_idx]`` pairs.

        Label arrays use ``self.max_seq_length`` as the "no link" sentinel:

        * ``itc_labels``: class index on the first token of each entity, else 0;
        * ``stc_labels``: for later tokens of an entity, the index of the previous
          token of the same entity;
        * ``el_labels_from_key``: for a class-3 token, the class-2 token linking to it;
        * ``el_labels``: for a class-2/3 token, the class-4 token linking to it
          (the original comments call 2/3/4 key/value/header).

        Returns:
            dict of tensors ``input_ids``, ``bbox`` (``(L, 4)`` ints on 0..1000),
            ``attention_mask``, ``itc_labels``, ``are_box_first_tokens``,
            ``stc_labels``, ``el_labels``, ``el_labels_from_key``, plus ``words``:
            the per-token word list padded with ``''`` to ``max_seq_length``.
        """
        json_obj = read_json(json_file)
        width = json_obj["meta"]["imageSize"]["width"]
        height = json_obj["meta"]["imageSize"]["height"]
        max_len = self.max_seq_length
        sentinel = max_len

        tokens, bbs, texts = [], [], []
        box_to_token_indices = []  # word index -> positions of its tokens
        box2token_span_map = []    # word index -> [first token pos, end pos) incl. [CLS] offset
        for word in json_obj["words"]:
            # Copy so that the parsed JSON is not mutated.
            word_tokens = list(word["layoutxlm_tokens"]) or [self.unk_token_id_layoutxlm]
            if len(tokens) + len(word_tokens) > max_len - 2:
                break
            start = len(tokens) + 1  # +1: position 0 is [CLS]
            stop = start + len(word_tokens)
            box2token_span_map.append([start, stop])
            box_to_token_indices.append(list(range(start, stop)))
            flat_box = self._clip_and_flatten(word["boundingBox"], width, height)
            tokens.extend(word_tokens)
            bbs.extend([flat_box] * len(word_tokens))
            texts.extend([word["text"]] * len(word_tokens))

        input_ids, attention_mask, bbox, are_box_first_tokens, seq_words, n = self._pack_sequence(
            tokens, bbs, texts, [idx[0] for idx in box_to_token_indices], width, height, max_len
        )
        lwords = seq_words + [''] * (max_len - n)

        itc_labels = np.zeros(max_len, dtype=int)
        stc_labels = np.full(max_len, sentinel, dtype=np.int64)
        el_labels = np.full(max_len, sentinel, dtype=int)
        el_labels_from_key = np.full(max_len, sentinel, dtype=int)

        # Entity (class) labels.
        classes_dic = json_obj["parse"]["class"]
        for class_name in self.class_names:
            if class_name == "others" or class_name not in classes_dic:
                continue
            for word_list in classes_dic[class_name]:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    if word_idx >= len(box_to_token_indices):
                        break  # word was truncated away
                    for converted_word_idx in box_to_token_indices[word_idx]:
                        if converted_word_idx >= max_len:
                            break  # out of idx
                        if is_first:
                            itc_labels[converted_word_idx] = self.class_idx_dic[class_name]
                            is_first, last_word_idx = False, converted_word_idx
                        else:
                            stc_labels[converted_word_idx] = last_word_idx
                            last_word_idx = converted_word_idx

        # Relation (entity-linking) labels, anchored on each word's first token.
        for relation in json_obj["parse"]["relations"]:
            if relation[0] >= len(box2token_span_map) or relation[1] >= len(box2token_span_map):
                continue
            word_from = box2token_span_map[relation[0]][0]
            word_to = box2token_span_map[relation[1]][0]
            if word_from >= max_len or word_to >= max_len:
                continue
            # Skip targets that already received both kinds of link.
            if el_labels[word_to] != sentinel and el_labels_from_key[word_to] != sentinel:
                continue
            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key-value)
            if itc_labels[word_from] == 4 and itc_labels[word_to] in (2, 3):
                el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

        return {
            "input_ids": torch.from_numpy(input_ids),
            "bbox": torch.from_numpy(bbox),
            "words": lwords,
            "attention_mask": torch.from_numpy(attention_mask),
            "itc_labels": torch.from_numpy(itc_labels),
            "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
            "stc_labels": torch.from_numpy(stc_labels),
            "el_labels": torch.from_numpy(el_labels),
            "el_labels_from_key": torch.from_numpy(el_labels_from_key),
        }
class DocumentKVUProcess(KVUProcess):
    """Encode a page as exactly ``max_window_count`` consecutive word windows.

    Window ``i`` covers words ``[i * window_size, (i + 1) * window_size)``, so
    windows do not overlap; words beyond the last window are not encoded.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        """Same arguments as :class:`KVUProcess`, plus ``max_window_count`` windows per page."""
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count
        # Short aliases for the inherited LayoutXLM special-token ids / tokenizer.
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> dict:
        """Load OCR for ``img_path`` and encode it with :meth:`preprocess`.

        OCR is run into ``tmp.txt`` only when ``run_ocr == 1`` and ``ocr_path``
        does not exist; otherwise ``ocr_path`` is read as-is.
        """
        if self.run_ocr == 1 and not os.path.exists(ocr_path):
            process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
            ocr_path = "tmp.txt"
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)
        width, height = imagesize.get(img_path)
        # Close the file handle deterministically; convert() returns an independent copy.
        with Image.open(img_path) as img:
            images = [img.convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
        return self.preprocess(lbboxes, lwords,
                               {'image': image_features, 'width': width, 'height': height, 'img_path': img_path})

    def _empty_window(self, feature_maps):
        """All-padding window (zero attention) used when the page has no words.

        ``bbox`` has the same ``(max_seq_length, 4)`` integer layout as a
        populated window so that windows can be batched together.
        """
        return {
            "image": feature_maps['image'],
            "input_ids": torch.from_numpy(np.full(self.max_seq_length, self.pad_token_id, dtype=int)),
            "bbox": torch.from_numpy(np.zeros((self.max_seq_length, 4), dtype=int)),
            "words": [],
            "attention_mask": torch.from_numpy(np.zeros(self.max_seq_length, dtype=int)),
            "are_box_first_tokens": torch.from_numpy(np.zeros(self.max_seq_length, dtype=np.bool_)),
        }

    def _encode_window(self, bounding_boxes, words, feature_maps):
        """Encode one window's words as a padded [CLS] ... [SEP] sequence.

        Words are added until the next word's tokens would not fit in
        ``max_seq_length - 2``. ``words`` in the result is padded with ``'</p>'``
        to ``max_seq_length - 2`` entries before [CLS]/[SEP] are added.
        """
        width, height = feature_maps['width'], feature_maps['height']
        n_text = self.max_seq_length - 2  # room left after [CLS] and [SEP]
        tokens, bbs, texts, first_token_indices = [], [], [], []
        for box, word in zip(bounding_boxes, words):
            word_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))
            if len(word_tokens) == 0:
                word_tokens.append(self.unk_token_id)
            if len(tokens) + len(word_tokens) > n_text:
                break
            corners = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], [box[0], box[3]]]
            flat_box = [
                value
                for x, y in corners
                for value in (max(0.0, min(x, width)), max(0.0, min(y, height)))
            ]
            first_token_indices.append(len(tokens) + 1)  # +1: position 0 is [CLS]
            tokens.extend(word_tokens)
            bbs.extend([flat_box] * len(word_tokens))
            texts.extend([word] * len(word_tokens))

        seq_tokens = [self.cls_token_id] + tokens + [self.sep_token_id]
        # [CLS] gets an all-zero box, [SEP] a box covering the whole image.
        seq_bbs = [[0.0] * 8] + bbs + [[width, height] * 4]
        texts.extend(['</p>'] * (n_text - len(texts)))
        seq_words = [self.tokenizer._cls_token] + texts[:n_text] + [self.tokenizer._sep_token]
        n = len(seq_tokens)

        input_ids = np.full(self.max_seq_length, self.pad_token_id, dtype=int)
        input_ids[:n] = seq_tokens
        attention_mask = np.zeros(self.max_seq_length, dtype=int)
        attention_mask[:n] = 1
        bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
        bbox[:n, :] = seq_bbs
        # Normalize to 0..1 by image size, keep corners 0 and 2, scale to 0..1000.
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
        bbox = (bbox[:, [0, 1, 4, 5]] * 1000).astype(int)
        are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
        are_box_first_tokens[[idx for idx in first_token_indices if idx < self.max_seq_length]] = True

        return {
            "image": feature_maps['image'],
            "input_ids": torch.from_numpy(input_ids),
            "bbox": torch.from_numpy(bbox),
            "words": seq_words,
            "attention_mask": torch.from_numpy(attention_mask),
            "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
        }

    def preprocess(self, bounding_boxes, words, feature_maps):
        """Split the page into ``max_window_count`` windows and also concatenate them.

        Windows that start past the last word repeat the previous window's
        contents; their positions are then masked out in the concatenated
        ``documents`` tensors. If the page has no words, every window is an
        all-padding window.

        Returns:
            ``{"windows": [...], "documents": {...}}`` where ``documents`` holds
            the concatenated ``attention_mask``, ``bbox``, ``are_box_first_tokens``
            and ``words`` of all windows, plus ``n_empty_windows``.
        """
        n_words = len(words)
        windows = []
        n_empty_windows = 0
        for i in range(self.max_window_count):
            if n_words == 0:
                n_empty_windows += 1
                windows.append(self._empty_window(feature_maps))
                continue
            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)
            if start_word_idx >= stop_word_idx:
                n_empty_windows += 1
                windows.append(windows[-1])
                continue
            windows.append(self._encode_window(
                bounding_boxes[start_word_idx:stop_word_idx],
                words[start_word_idx:stop_word_idx],
                feature_maps,
            ))

        attention_mask = torch.cat([w['attention_mask'] for w in windows])
        are_box_first_tokens = torch.cat([w['are_box_first_tokens'] for w in windows])
        if n_empty_windows > 0:
            # Empty windows are always the trailing ones.
            first_empty = self.max_seq_length * (self.max_window_count - n_empty_windows)
            attention_mask[first_empty:] = 0
            are_box_first_tokens[first_empty:] = False
        bbox = torch.cat([w['bbox'] for w in windows])
        document_words = [word for w in windows for word in w['words']]

        return {
            'windows': windows,
            'documents': {
                "attention_mask": attention_mask,
                "bbox": bbox,
                "are_box_first_tokens": are_box_first_tokens,
                "n_empty_windows": n_empty_windows,
                "words": document_words,
            },
        }