607 lines
21 KiB
Python
Executable File
607 lines
21 KiB
Python
Executable File
from builtins import dict
|
|
from common.utils.global_variables import *
|
|
import logging
|
|
import logging.config
|
|
from utils.logging.logging import LOGGER_CONFIG
|
|
# Load the logging configuration
|
|
logging.config.dictConfig(LOGGER_CONFIG)
|
|
# Get the logger
|
|
logger = logging.getLogger(__name__)
|
|
|
|
MIN_IOU_HEIGHT = 0.7
|
|
MIN_WIDTH_LINE_RATIO = 0.05
|
|
|
|
|
|
class Word:
|
|
def __init__(
|
|
self,
|
|
text="",
|
|
image=None,
|
|
conf_detect=0.0,
|
|
conf_cls=0.0,
|
|
bndbox=None,
|
|
kie_label="",
|
|
):
|
|
self.type = "word"
|
|
self.text = text
|
|
self.image = image
|
|
self.conf_detect = conf_detect
|
|
self.conf_cls = conf_cls
|
|
self.boundingbox = bndbox if bndbox is not None else [-1, -1, -1, -1]# [left, top,right,bot] coordinate of top-left and bottom-right point
|
|
self.word_id = 0 # id of word
|
|
self.word_group_id = 0 # id of word_group which instance belongs to
|
|
self.line_id = 0 # id of line which instance belongs to
|
|
self.paragraph_id = 0 # id of line which instance belongs to
|
|
self.kie_label = kie_label
|
|
|
|
def invalid_size(self):
|
|
return (self.boundingbox[2] - self.boundingbox[0]) * (
|
|
self.boundingbox[3] - self.boundingbox[1]
|
|
) > 0
|
|
|
|
def is_special_word(self):
|
|
left, top, right, bottom = self.boundingbox
|
|
width, height = right - left, bottom - top
|
|
text = self.text
|
|
|
|
if text is None:
|
|
return True
|
|
|
|
if len(text) >= 7:
|
|
no_digits = sum(c.isdigit() for c in text)
|
|
return no_digits / len(text) >= 0.3
|
|
|
|
return False
|
|
|
|
|
|
class Word_group:
|
|
def __init__(self):
|
|
self.type = "word_group"
|
|
self.list_words = [] # dict of word instances
|
|
self.word_group_id = 0 # word group id
|
|
self.line_id = 0 # id of line which instance belongs to
|
|
self.paragraph_id = 0 # id of paragraph which instance belongs to
|
|
self.text = ""
|
|
self.boundingbox = [-1, -1, -1, -1]
|
|
self.kie_label = ""
|
|
|
|
def add_word(self, word: Word): # add a word instance to the word_group
|
|
if word.text != "✪":
|
|
for w in self.list_words:
|
|
if word.word_id == w.word_id:
|
|
logger.info("Word id collision")
|
|
return False
|
|
word.word_group_id = self.word_group_id #
|
|
word.line_id = self.line_id
|
|
word.paragraph_id = self.paragraph_id
|
|
self.list_words.append(word)
|
|
self.text += " " + word.text
|
|
if self.boundingbox == [-1, -1, -1, -1]:
|
|
self.boundingbox = word.boundingbox
|
|
else:
|
|
self.boundingbox = [
|
|
min(self.boundingbox[0], word.boundingbox[0]),
|
|
min(self.boundingbox[1], word.boundingbox[1]),
|
|
max(self.boundingbox[2], word.boundingbox[2]),
|
|
max(self.boundingbox[3], word.boundingbox[3]),
|
|
]
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def update_word_group_id(self, new_word_group_id):
|
|
self.word_group_id = new_word_group_id
|
|
for i in range(len(self.list_words)):
|
|
self.list_words[i].word_group_id = new_word_group_id
|
|
|
|
def update_kie_label(self):
|
|
list_kie_label = [word.kie_label for word in self.list_words]
|
|
dict_kie = dict()
|
|
for label in list_kie_label:
|
|
if label not in dict_kie:
|
|
dict_kie[label] = 1
|
|
else:
|
|
dict_kie[label] += 1
|
|
max_value = max(list(dict_kie.values()))
|
|
list_keys = list(dict_kie.keys())
|
|
list_values = list(dict_kie.values())
|
|
self.kie_label = list_keys[list_values.index(max_value)]
|
|
|
|
def update_text(self): # update text after changing positions of words in list word
|
|
text = ""
|
|
for word in self.list_words:
|
|
text += " " + word.text
|
|
self.text = text
|
|
|
|
|
|
class Line:
|
|
def __init__(self):
|
|
self.type = "line"
|
|
self.list_word_groups = [] # list of Word_group instances in the line
|
|
self.line_id = 0 # id of line in the paragraph
|
|
self.paragraph_id = 0 # id of paragraph which instance belongs to
|
|
self.text = ""
|
|
self.boundingbox = [-1, -1, -1, -1]
|
|
|
|
def add_group(self, word_group: Word_group): # add a word_group instance
|
|
if word_group.list_words is not None:
|
|
for wg in self.list_word_groups:
|
|
if word_group.word_group_id == wg.word_group_id:
|
|
logger.info("Word_group id collision")
|
|
return False
|
|
|
|
self.list_word_groups.append(word_group)
|
|
self.text += word_group.text
|
|
word_group.paragraph_id = self.paragraph_id
|
|
word_group.line_id = self.line_id
|
|
|
|
for i in range(len(word_group.list_words)):
|
|
word_group.list_words[
|
|
i
|
|
].paragraph_id = self.paragraph_id # set paragraph_id for word
|
|
word_group.list_words[i].line_id = self.line_id # set line_id for word
|
|
return True
|
|
return False
|
|
|
|
def update_line_id(self, new_line_id):
|
|
self.line_id = new_line_id
|
|
for i in range(len(self.list_word_groups)):
|
|
self.list_word_groups[i].line_id = new_line_id
|
|
for j in range(len(self.list_word_groups[i].list_words)):
|
|
self.list_word_groups[i].list_words[j].line_id = new_line_id
|
|
|
|
def merge_word(self, word): # word can be a Word instance or a Word_group instance
|
|
if word.text != "✪":
|
|
if self.boundingbox == [-1, -1, -1, -1]:
|
|
self.boundingbox = word.boundingbox
|
|
else:
|
|
self.boundingbox = [
|
|
min(self.boundingbox[0], word.boundingbox[0]),
|
|
min(self.boundingbox[1], word.boundingbox[1]),
|
|
max(self.boundingbox[2], word.boundingbox[2]),
|
|
max(self.boundingbox[3], word.boundingbox[3]),
|
|
]
|
|
self.list_word_groups.append(word)
|
|
self.text += " " + word.text
|
|
return True
|
|
return False
|
|
|
|
def __cal_ratio(self, top1, bottom1, top2, bottom2):
|
|
sorted_vals = sorted([top1, bottom1, top2, bottom2])
|
|
intersection = sorted_vals[2] - sorted_vals[1]
|
|
min_height = min(bottom1 - top1, bottom2 - top2)
|
|
if min_height == 0:
|
|
return -1
|
|
ratio = intersection / min_height
|
|
return ratio
|
|
|
|
def __cal_ratio_height(self, top1, bottom1, top2, bottom2):
|
|
|
|
height1, height2 = top1 - bottom1, top2 - bottom2
|
|
ratio_height = float(max(height1, height2)) / float(min(height1, height2))
|
|
return ratio_height
|
|
|
|
def in_same_line(self, input_line, thresh=0.7):
|
|
# calculate iou in vertical direction
|
|
_, top1, _, bottom1 = self.boundingbox
|
|
_, top2, _, bottom2 = input_line.boundingbox
|
|
|
|
ratio = self.__cal_ratio(top1, bottom1, top2, bottom2)
|
|
ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2)
|
|
|
|
if (
|
|
(top1 in range(top2, bottom2) or top2 in range(top1, bottom1))
|
|
and ratio >= thresh
|
|
and (ratio_height < 2)
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
class Paragraph:
|
|
def __init__(self, id=0, lines=None):
|
|
self.list_lines = (
|
|
lines if lines is not None else []
|
|
) # list of all lines in the paragraph
|
|
self.paragraph_id = id # index of paragraph in the ist of paragraph
|
|
self.text = ""
|
|
self.boundingbox = [-1, -1, -1, -1]
|
|
|
|
def add_line(self, line: Line): # add a line instance
|
|
if line.list_word_groups is not None:
|
|
for l in self.list_lines:
|
|
if line.line_id == l.line_id:
|
|
logger.info("Line id collision")
|
|
return False
|
|
for i in range(len(line.list_word_groups)):
|
|
line.list_word_groups[
|
|
i
|
|
].paragraph_id = (
|
|
self.paragraph_id
|
|
) # set paragraph id for every word group in line
|
|
for j in range(len(line.list_word_groups[i].list_words)):
|
|
line.list_word_groups[i].list_words[
|
|
j
|
|
].paragraph_id = (
|
|
self.paragraph_id
|
|
) # set paragraph id for every word in word groups
|
|
line.paragraph_id = self.paragraph_id # set paragraph id for line
|
|
self.list_lines.append(line) # add line to paragraph
|
|
self.text += " " + line.text
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def update_paragraph_id(
|
|
self, new_paragraph_id
|
|
): # update new paragraph_id for all lines, word_groups, words inside paragraph
|
|
self.paragraph_id = new_paragraph_id
|
|
for i in range(len(self.list_lines)):
|
|
self.list_lines[
|
|
i
|
|
].paragraph_id = new_paragraph_id # set new paragraph_id for line
|
|
for j in range(len(self.list_lines[i].list_word_groups)):
|
|
self.list_lines[i].list_word_groups[
|
|
j
|
|
].paragraph_id = new_paragraph_id # set new paragraph_id for word_group
|
|
for k in range(len(self.list_lines[i].list_word_groups[j].list_words)):
|
|
self.list_lines[i].list_word_groups[j].list_words[
|
|
k
|
|
].paragraph_id = new_paragraph_id # set new paragraph id for word
|
|
return True
|
|
|
|
|
|
def resize_to_original(
|
|
boundingbox, scale
|
|
): # resize coordinates to match size of original image
|
|
left, top, right, bottom = boundingbox
|
|
left *= scale[1]
|
|
right *= scale[1]
|
|
top *= scale[0]
|
|
bottom *= scale[0]
|
|
return [left, top, right, bottom]
|
|
|
|
|
|
def check_iomin(word: Word, word_group: Word_group):
|
|
min_height = min(
|
|
word.boundingbox[3] - word.boundingbox[1],
|
|
word_group.boundingbox[3] - word_group.boundingbox[1],
|
|
)
|
|
intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max(
|
|
word.boundingbox[1], word_group.boundingbox[1]
|
|
)
|
|
if intersect / min_height > 0.7:
|
|
return True
|
|
return False
|
|
|
|
|
|
def prepare_line(words):
|
|
lines = []
|
|
visited = [False] * len(words)
|
|
for id_word, word in enumerate(words):
|
|
if word.invalid_size() == 0:
|
|
continue
|
|
new_line = True
|
|
for i in range(len(lines)):
|
|
if (
|
|
lines[i].in_same_line(word) and not visited[id_word]
|
|
): # check if word is in the same line with lines[i]
|
|
lines[i].merge_word(word)
|
|
new_line = False
|
|
visited[id_word] = True
|
|
|
|
if new_line == True:
|
|
new_line = Line()
|
|
new_line.merge_word(word)
|
|
lines.append(new_line)
|
|
|
|
# logger.info(len(lines))
|
|
# sort line from top to bottom according top coordinate
|
|
lines.sort(key=lambda x: x.boundingbox[1])
|
|
return lines
|
|
|
|
|
|
def __create_word_group(word, word_group_id):
|
|
new_word_group = Word_group()
|
|
new_word_group.word_group_id = word_group_id
|
|
new_word_group.add_word(word)
|
|
|
|
return new_word_group
|
|
|
|
|
|
def __sort_line(line):
|
|
line.list_word_groups.sort(
|
|
key=lambda x: x.boundingbox[0]
|
|
) # sort word in lines from left to right
|
|
|
|
return line
|
|
|
|
|
|
def __merge_text_for_line(line):
|
|
line.text = ""
|
|
for word in line.list_word_groups:
|
|
line.text += " " + word.text
|
|
|
|
return line
|
|
|
|
|
|
def __update_list_word_groups(line, word_group_id, word_id, line_width):
|
|
|
|
old_list_word_group = line.list_word_groups
|
|
list_word_groups = []
|
|
|
|
inital_word_group = __create_word_group(old_list_word_group[0], word_group_id)
|
|
old_list_word_group[0].word_id = word_id
|
|
list_word_groups.append(inital_word_group)
|
|
word_group_id += 1
|
|
word_id += 1
|
|
|
|
for word in old_list_word_group[1:]:
|
|
check_word_group = True
|
|
word.word_id = word_id
|
|
word_id += 1
|
|
|
|
if (
|
|
(not list_word_groups[-1].text.endswith(":"))
|
|
and (
|
|
(word.boundingbox[0] - list_word_groups[-1].boundingbox[2]) / line_width
|
|
< MIN_WIDTH_LINE_RATIO
|
|
)
|
|
and check_iomin(word, list_word_groups[-1])
|
|
):
|
|
list_word_groups[-1].add_word(word)
|
|
check_word_group = False
|
|
|
|
if check_word_group:
|
|
new_word_group = __create_word_group(word, word_group_id)
|
|
list_word_groups.append(new_word_group)
|
|
word_group_id += 1
|
|
line.list_word_groups = list_word_groups
|
|
return line, word_group_id, word_id
|
|
|
|
|
|
def construct_word_groups_in_each_line(lines):
|
|
line_id = 0
|
|
word_group_id = 0
|
|
word_id = 0
|
|
for i in range(len(lines)):
|
|
if len(lines[i].list_word_groups) == 0:
|
|
continue
|
|
|
|
# left, top ,right, bottom
|
|
line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
|
|
|
|
lines[i] = __sort_line(lines[i])
|
|
|
|
# update text for lines after sorting
|
|
lines[i] = __merge_text_for_line(lines[i])
|
|
|
|
lines[i], word_group_id, word_id = __update_list_word_groups(
|
|
lines[i], word_group_id, word_id, line_width
|
|
)
|
|
lines[i].update_line_id(line_id)
|
|
line_id += 1
|
|
return lines
|
|
|
|
|
|
def words_to_lines(words, check_special_lines=True): # words is list of Word instance
|
|
# sort word by top
|
|
words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0]))
|
|
number_of_word = len(words)
|
|
# logger.info(number_of_word)
|
|
# sort list words to list lines, which have not contained word_group yet
|
|
lines = prepare_line(words)
|
|
|
|
# construct word_groups in each line
|
|
lines = construct_word_groups_in_each_line(lines)
|
|
return lines, number_of_word
|
|
|
|
|
|
def near(word_group1: Word_group, word_group2: Word_group):
|
|
min_height = min(
|
|
word_group1.boundingbox[3] - word_group1.boundingbox[1],
|
|
word_group2.boundingbox[3] - word_group2.boundingbox[1],
|
|
)
|
|
overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max(
|
|
word_group1.boundingbox[1], word_group2.boundingbox[1]
|
|
)
|
|
|
|
if overlap > 0:
|
|
return True
|
|
if abs(overlap / min_height) < 1.5:
|
|
logger.info("near enough", abs(overlap / min_height), overlap, min_height)
|
|
return True
|
|
return False
|
|
|
|
|
|
def calculate_iou_and_near(wg1: Word_group, wg2: Word_group):
|
|
min_height = min(
|
|
wg1.boundingbox[3] - wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1]
|
|
)
|
|
overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max(
|
|
wg1.boundingbox[1], wg2.boundingbox[1]
|
|
)
|
|
iou = overlap / min_height
|
|
distance = min(
|
|
abs(wg1.boundingbox[0] - wg2.boundingbox[2]),
|
|
abs(wg1.boundingbox[2] - wg2.boundingbox[0]),
|
|
)
|
|
if iou > 0.7 and distance < 0.5 * (wg1.boundingboxp[2] - wg1.boundingbox[0]):
|
|
return True
|
|
return False
|
|
|
|
|
|
def construct_word_groups_to_kie_label(list_word_groups: list):
|
|
kie_dict = dict()
|
|
for wg in list_word_groups:
|
|
if wg.kie_label == "other":
|
|
continue
|
|
if wg.kie_label not in kie_dict:
|
|
kie_dict[wg.kie_label] = [wg]
|
|
else:
|
|
kie_dict[wg.kie_label].append(wg)
|
|
|
|
new_dict = dict()
|
|
for key, value in kie_dict.items():
|
|
if len(value) == 1:
|
|
new_dict[key] = value
|
|
continue
|
|
|
|
value.sort(key=lambda x: x.boundingbox[1])
|
|
new_dict[key] = value
|
|
return new_dict
|
|
|
|
|
|
def invoice_construct_word_groups_to_kie_label(list_word_groups: list):
|
|
kie_dict = dict()
|
|
|
|
for wg in list_word_groups:
|
|
if wg.kie_label == "other":
|
|
continue
|
|
if wg.kie_label not in kie_dict:
|
|
kie_dict[wg.kie_label] = [wg]
|
|
else:
|
|
kie_dict[wg.kie_label].append(wg)
|
|
|
|
return kie_dict
|
|
|
|
|
|
def postprocess_total_value(kie_dict):
|
|
if "total_in_words_value" not in kie_dict:
|
|
return kie_dict
|
|
|
|
for k, value in kie_dict.items():
|
|
if k == "total_in_words_value":
|
|
continue
|
|
l = []
|
|
for v in value:
|
|
if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]:
|
|
l.append(v)
|
|
|
|
if len(l) != 0:
|
|
kie_dict[k] = l
|
|
|
|
return kie_dict
|
|
|
|
|
|
def postprocess_tax_code_value(kie_dict):
|
|
if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict:
|
|
return kie_dict
|
|
|
|
kie_dict["buyer_tax_code_value"] = []
|
|
for v in kie_dict["seller_tax_code_value"]:
|
|
if "buyer_name_key" in kie_dict and (
|
|
v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3]
|
|
or near(v, kie_dict["buyer_name_key"][0])
|
|
):
|
|
kie_dict["buyer_tax_code_value"].append(v)
|
|
continue
|
|
|
|
if "buyer_name_value" in kie_dict and (
|
|
v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3]
|
|
or near(v, kie_dict["buyer_name_value"][0])
|
|
):
|
|
kie_dict["buyer_tax_code_value"].append(v)
|
|
continue
|
|
|
|
if "buyer_address_value" in kie_dict and near(
|
|
kie_dict["buyer_address_value"][0], v
|
|
):
|
|
kie_dict["buyer_tax_code_value"].append(v)
|
|
return kie_dict
|
|
|
|
|
|
def postprocess_tax_code_key(kie_dict):
|
|
if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict:
|
|
return kie_dict
|
|
kie_dict["buyer_tax_code_key"] = []
|
|
for v in kie_dict["seller_tax_code_key"]:
|
|
if "buyer_name_key" in kie_dict and (
|
|
v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3]
|
|
or near(v, kie_dict["buyer_name_key"][0])
|
|
):
|
|
kie_dict["buyer_tax_code_key"].append(v)
|
|
continue
|
|
|
|
if "buyer_name_value" in kie_dict and (
|
|
v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3]
|
|
or near(v, kie_dict["buyer_name_value"][0])
|
|
):
|
|
kie_dict["buyer_tax_code_key"].append(v)
|
|
continue
|
|
|
|
if "buyer_address_value" in kie_dict and near(
|
|
kie_dict["buyer_address_value"][0], v
|
|
):
|
|
kie_dict["buyer_tax_code_key"].append(v)
|
|
|
|
return kie_dict
|
|
|
|
|
|
def invoice_postprocess(kie_dict: dict):
|
|
# all keys or values which are below total_in_words_value will be thrown away
|
|
kie_dict = postprocess_total_value(kie_dict)
|
|
kie_dict = postprocess_tax_code_value(kie_dict)
|
|
kie_dict = postprocess_tax_code_key(kie_dict)
|
|
return kie_dict
|
|
|
|
|
|
def throw_overlapping_words(list_words):
|
|
new_list = [list_words[0]]
|
|
for word in list_words:
|
|
overlap = False
|
|
area = (word.boundingbox[2] - word.boundingbox[0]) * (
|
|
word.boundingbox[3] - word.boundingbox[1]
|
|
)
|
|
for word2 in new_list:
|
|
area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * (
|
|
word2.boundingbox[3] - word2.boundingbox[1]
|
|
)
|
|
xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0])
|
|
xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2])
|
|
ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1])
|
|
ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3])
|
|
if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
|
|
continue
|
|
|
|
area_intersect = (xmax_intersect - xmin_intersect) * (
|
|
ymax_intersect - ymin_intersect
|
|
)
|
|
if area_intersect / area > 0.7 or area_intersect / area2 > 0.7:
|
|
overlap = True
|
|
if overlap == False:
|
|
new_list.append(word)
|
|
return new_list
|
|
|
|
|
|
class Box:
|
|
def __init__(self, xmin=0, ymin=0, xmax=0, ymax=0, label="", kie_label=""):
|
|
self.xmax = xmax
|
|
self.ymax = ymax
|
|
self.xmin = xmin
|
|
self.ymin = ymin
|
|
self.label = label
|
|
self.kie_label = kie_label
|
|
|
|
|
|
def check_iou(box1: Word, box2: Box, threshold=0.9):
|
|
area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * (
|
|
box1.boundingbox[3] - box1.boundingbox[1]
|
|
)
|
|
area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)
|
|
xmin_intersect = max(box1.boundingbox[0], box2.xmin)
|
|
ymin_intersect = max(box1.boundingbox[1], box2.ymin)
|
|
xmax_intersect = min(box1.boundingbox[2], box2.xmax)
|
|
ymax_intersect = min(box1.boundingbox[3], box2.ymax)
|
|
if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
|
|
area_intersect = 0
|
|
else:
|
|
area_intersect = (xmax_intersect - xmin_intersect) * (
|
|
ymax_intersect - ymin_intersect
|
|
)
|
|
union = area1 + area2 - area_intersect
|
|
iou = area_intersect / union
|
|
if iou > threshold:
|
|
return True
|
|
return False
|