139 lines
5.7 KiB
Python
Executable File
139 lines
5.7 KiB
Python
Executable File
import yaml
|
|
import numpy as np
|
|
|
|
from .config.sift_based_aligner import config
|
|
from .modules.sift_based_aligner import SIFTBasedAligner
|
|
from .utils.common import read_json
|
|
from common.utils.word_formation import Word, words_to_lines
|
|
|
|
|
|
def calc_pct_overlapped_area(bboxes1, bboxes2):
    """Return, for every box pair, the fraction of the second box covered by the first.

    Boxes are ``(x1, y1, x2, y2)`` rows using inclusive pixel coordinates, so a
    box's width is ``x2 - x1 + 1``. Note this is not an IoU: the intersection is
    divided by the area of the ``bboxes2`` box only.

    Args:
        bboxes1: array-like of shape (N1, 4).
        bboxes2: array-like of shape (N2, 4).

    Returns:
        np.ndarray of float, shape (N1, N2), where ``[i, j]`` is
        ``intersection(bboxes1[i], bboxes2[j]) / area(bboxes2[j])``.
        Columns for boxes in ``bboxes2`` with non-positive area are 0.

    Raises:
        ValueError: if either input is not a 2-D array with 4 columns.
    """
    # Work in float: unsigned-integer inputs would otherwise wrap around when
    # ``x2 - x1`` goes negative for non-overlapping boxes.
    bboxes1 = np.asarray(bboxes1, dtype=float)
    bboxes2 = np.asarray(bboxes2, dtype=float)
    for name, arr in (("bboxes1", bboxes1), ("bboxes2", bboxes2)):
        if arr.ndim != 2 or arr.shape[1] != 4:
            raise ValueError(f"{name} must have shape (N, 4), got {arr.shape}")

    # Coordinates of bboxes1 are (N1, 1) columns and those of bboxes2 are
    # (1, N2) rows, so every pairwise expression below broadcasts to (N1, N2).
    x11, y11, x12, y12 = np.split(bboxes1, 4, axis=1)
    x21, y21, x22, y22 = np.split(bboxes2.T, 4, axis=0)

    inter_w = np.maximum(np.minimum(x12, x22) - np.maximum(x11, x21) + 1, 0)
    inter_h = np.maximum(np.minimum(y12, y22) - np.maximum(y11, y21) + 1, 0)
    inter_area = inter_w * inter_h

    area2 = (x22 - x21 + 1) * (y22 - y21 + 1)  # shape (1, N2)

    # Degenerate boxes (area <= 0) get 0 instead of nan/inf and a RuntimeWarning.
    ratio = np.zeros(inter_area.shape, dtype=float)
    np.divide(inter_area, area2, out=ratio, where=area2 > 0)
    return ratio
|
|
|
|
|
|
class Predictor:
    """Align document images to templates and assign word boxes to template fields.

    Settings come from a YAML file whose ``templates`` -> ``config`` mapping is
    used as the base keyword arguments for ``SIFTBasedAligner``.
    """

    def __init__(self, setting_file="setting.yml"):
        """Load settings from the YAML file ``setting_file``.

        Raises:
            KeyError: if the loaded mapping has no ``templates`` -> ``config``.
        """
        # safe_load: never construct arbitrary Python objects from the file.
        with open(setting_file, encoding="utf-8") as f:
            self.setting = yaml.safe_load(f)
        self.config = self.setting["templates"]["config"]

    def _align(self, config, temp_name, image):
        """Align ``image`` to template ``temp_name`` with a ``SIFTBasedAligner``.

        ``config`` is passed as keyword arguments to ``SIFTBasedAligner``.
        Returns the aligner's output for ``image``.
        """
        aligner = SIFTBasedAligner(**config)
        metadata = [{"doc_type": temp_name}]
        # The aligner takes a batch; wrap the single image and unwrap the result.
        # NOTE(review): ``run_alige`` is the method name called on the aligner;
        # confirm against SIFTBasedAligner before renaming.
        aligned_images = aligner.run_alige([image], metadata)
        return aligned_images[0]

    def _reorder_words(self, boxes):
        """Return indices that sort ``boxes`` left-to-right by x1 (column 0)."""
        return np.argsort(boxes[:, 0])

    def _asign_words_to_field(
        self, boxes, contents, types, page_template_info, threshold=0.8
    ):
        """Assign word boxes to the labelled fields of a template page.

        Args:
            boxes: batch of word boxes; only ``boxes[0]`` is used, as a
                sequence of ``(x1, y1, x2, y2)`` rows.
            contents: per-word values, indexed in parallel with ``boxes[0]``.
            types: per-word type strings, parallel to ``contents``; only
                ``"checkbox"`` is treated specially.
            page_template_info: dict with a ``"fields"`` list; each field has a
                ``"box"`` (4 coordinates) and a ``"label"``. Labels starting
                with ``"checkbox"`` mark checkbox fields.
            threshold: a word belongs to a field when the fraction of the word
                box covered by the field box exceeds this value.

        Returns:
            dict keyed by field label. Checkbox fields map to
            ``{"value", "boxes"}`` for the first qualifying checkbox word, or
            ``None`` values if there is none. Other fields map to
            ``{"words", "value", "boxes"}``, with words sorted left-to-right and
            ``value`` being the words grouped into newline-joined lines.
        """
        fields = page_template_info["fields"]
        # Going through float lets numeric strings such as "12.5" convert, and
        # truncates fractional coordinates to int pixels.
        field_coords = np.array([field["box"] for field in fields])
        field_coords = field_coords.astype(float).astype(int)
        field_names = [field["label"] for field in fields]
        field_types = [
            "checkbox" if name.startswith("checkbox") else "word"
            for name in field_names
        ]
        word_boxes = np.array(boxes[0])

        # An empty list becomes a shape-(0,) array; give it the (0, 4) shape the
        # overlap computation requires so pages without words/fields don't crash.
        if field_coords.size == 0:
            field_coords = field_coords.reshape(0, 4)
        if word_boxes.size == 0:
            word_boxes = word_boxes.reshape(0, 4)

        # area_pct[i, j]: fraction of word j's box covered by field i's box.
        area_pct = calc_pct_overlapped_area(field_coords, word_boxes)

        results = {}
        for row_score, field, field_type in zip(area_pct, field_names, field_types):
            inds = np.flatnonzero(row_score > threshold)
            if field_type == "checkbox":
                # Only checkbox detections count; the lowest word index wins.
                checkbox_inds = [i for i in inds if types[i] == "checkbox"]
                if checkbox_inds:
                    first = checkbox_inds[0]
                    results[field] = {
                        "value": contents[first],
                        "boxes": word_boxes[first],
                    }
                else:
                    results[field] = {"value": None, "boxes": None}
            else:
                sorted_inds = inds[self._reorder_words(word_boxes[inds])]
                words = [contents[i] for i in sorted_inds]
                lines = self._get_line_content(word_boxes[sorted_inds], words)
                results[field] = {
                    "words": words,
                    "value": "\n".join(lines).strip(),
                    "boxes": word_boxes[sorted_inds],
                }
        return results

    def _get_line_content(self, boxes, contents):
        """Group words into lines with ``words_to_lines`` and return line texts.

        Each box is cast to a list of ints and paired with its text as a ``Word``.
        """
        list_words = [
            Word(text=text, bndbox=[int(coord) for coord in box])
            for box, text in zip(boxes, contents)
        ]
        list_lines, _ = words_to_lines(list_words)
        return [line.text for line in list_lines]

    def align_image(self, image, template_json, template_image_dir, temp_name):
        """Align ``image`` to the template named ``temp_name``.

        Args:
            image: the document image to align.
            template_json: template info handed to the aligner as
                ``template_info``, e.g. a mapping from template name to the
                path of its field-annotation JSON:
                ``{"pos01": ".../pos01.json", "cccd_front": ".../cccd_front.json"}``.
            template_image_dir (str): path to the template image directory,
                handed to the aligner as ``template_im_dir``.
            temp_name (str): template name, passed as the ``doc_type`` metadata.

        Returns:
            The aligned image produced by the aligner for ``image``.
        """
        # Shallow copy so per-call keys never leak into the shared base config.
        config = self.config.copy()
        config["template_info"] = template_json
        config["template_im_dir"] = template_image_dir
        return self._align(config, temp_name, image)

    def template_based_extractor(self, batch_boxes, texts, doc_page, template_json):
        """Extract per-field values for one page.

        Args:
            batch_boxes: batch of word boxes; only the first entry is used.
            texts: word values parallel to ``batch_boxes[0]``.
            doc_page: dict whose ``"types"`` list gives each word's type.
            template_json: page template info with a ``"fields"`` list.

        Returns:
            dict: per-field results as produced by ``_asign_words_to_field``.
        """
        return self._asign_words_to_field(
            batch_boxes,
            texts,
            doc_page["types"],
            template_json,
        )