sbt-idp/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py

import os
import cv2
import json
import torch
import glob
import re
import numpy as np
from tqdm import tqdm
from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import vat_standardizer, get_string, ap_standardizer
from kvu_dictionary import vat_dictionary, ap_dictionary
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
        logger.info("Directory already exists.")
logger.info('Save dir : {}'.format(save_dir))
def pdf2image(pdf_dir, save_dir):
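    '''
    Convert every PDF in pdf_dir to 500-DPI JPEG pages saved under save_dir,
    one image per page, named <pdf_name>_<page_index>.jpg.
    '''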
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
    logger.info('No. pdf files: {}'.format(len(pdf_files)))
for file in tqdm(pdf_files):
pages = convert_from_path(file, 500)
for i, page in enumerate(pages):
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
logger.info('Done!!!')
def xyxy2xywh(bbox):
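    '''Convert a bounding box from [x1, y1, x2, y2] to [x, y, width, height].'''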
return [
float(bbox[0]),
float(bbox[1]),
float(bbox[2]) - float(bbox[0]),
float(bbox[3]) - float(bbox[1]),
]
def write_to_json(file_path, content):
with open(file_path, mode='w', encoding='utf8') as f:
json.dump(content, f, ensure_ascii=False)
def read_json(file_path):
with open(file_path, 'r') as f:
return json.load(f)
def read_xml(file_path):
with open(file_path, 'r') as xml_file:
return xml_file.read()
def write_to_xml(file_path, content):
with open(file_path, mode="w", encoding='utf8') as f:
f.write(content)
def write_to_xml_from_dict(file_path, content):
    xml = dicttoxml(content)
    xml_decode = xml.decode()
with open(file_path, mode="w") as f:
f.write(xml_decode)
def load_ocr_result(ocr_path):
with open(ocr_path, 'r') as f:
lines = f.read().splitlines()
preds = []
for line in lines:
preds.append(line.split('\t'))
return preds
def post_process_basic_ocr(lwords: list) -> list:
pp_lwords = []
for word in lwords:
pp_lwords.append(word.replace("", " "))
return pp_lwords
def read_ocr_result_from_txt(file_path: str):
'''
return list of bounding boxes, list of words
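    Each non-empty line is expected to be tab-separated:
    x1<TAB>y1<TAB>x2<TAB>y2<TAB>text[<TAB>extra], where the optional last column is ignored.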
'''
with open(file_path, 'r') as f:
lines = f.read().splitlines()
boxes, words = [], []
for line in lines:
if line == "":
continue
word_info = line.split("\t")
        if len(word_info) == 6:
            x1, y1, x2, y2, text, _ = word_info
        elif len(word_info) == 5:
            x1, y1, x2, y2, text = word_info
        else:
            continue
        x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
if text and text != " ":
words.append(text)
boxes.append((x1, y1, x2, y2))
return boxes, words
def get_colormap():
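    '''Class name -> BGR color (OpenCV convention) used by visualize().'''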
return {
'others': (0, 0, 255), # others: red
'title': (0, 255, 255), # title: yellow
'key': (255, 0, 0), # key: blue
'value': (0, 255, 0), # value: green
'header': (233, 197, 15), # header
'group': (0, 128, 128), # group
        'relation': (0, 0, 255),      # relation: red
}
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
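    '''
    Draw predicted word-group boxes, intra-group links, and key-value relation
    lines on the image. Boxes in `bbox` are assumed to be normalized to the
    0-1000 range and are rescaled to pixel coordinates; EXIF orientation is
    applied before drawing.
    '''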
exif = image._getexif()
orientation = None
if exif is not None:
orientation = exif.get(0x0112)
# Convert the PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# Rotate the image in OpenCV if necessary
if orientation == 3:
image = cv2.rotate(image, cv2.ROTATE_180)
elif orientation == 6:
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
elif orientation == 8:
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
image = np.asarray(image)
if len(image.shape) == 2:
image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
assert len(image.shape) == 3
if orientation is not None and orientation == 6:
width, height, _ = image.shape
else:
height, width, _ = image.shape
if len(pr_class_words) > 0:
id2label = {k: labels[k] for k in range(len(labels))}
for lb, groups in enumerate(pr_class_words):
if lb == 0:
continue
for group_id, group in enumerate(groups):
for i, word_id in enumerate(group):
x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
if i == 0:
x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
else:
x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
x_center0, y_center0 = x_center1, y_center1
if len(pr_relations) > 0:
for pair in pr_relations:
xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)
x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
return image
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
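    '''
    For every linked pair that contains one `rel_from` and one `rel_to` element,
    return {rel_to group_id: [rel_from group_id, rel_to group_id]}.
    '''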
outputs = {}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] in (rel_from, rel_to):
is_rel[element['class']]['status'] = 1
is_rel[element['class']]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
return outputs
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
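    '''
    Return {key group_id: [value group_id, ...]} restricted to keys that appear
    in `header_key_pairs`, i.e. keys already linked to a table header.
    '''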
list_keys = list(header_key_pairs.keys())
relations = {k: [] for k in list_keys}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] == rel_from and element['group_id'] in list_keys:
is_rel[rel_from]['status'] = 1
is_rel[rel_from]['value'] = element
if element['class'] == rel_to:
is_rel[rel_to]['status'] = 1
is_rel[rel_to]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
return relations
def get_key2values_relations(key_value_pairs: dict):
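    '''Invert {value group_id: [key group_id, value group_id]} into {key group_id: [value group_id, ...]}.'''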
triple_linkings = {}
for value_group_id, key_value_pair in key_value_pairs.items():
key_group_id = key_value_pair[0]
if key_group_id not in list(triple_linkings.keys()):
triple_linkings[key_group_id] = []
triple_linkings[key_group_id].append(value_group_id)
return triple_linkings
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
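    '''
    Merge token-level predictions into word groups:
    {group_id: {'group_id', 'text', 'class', 'tokens'}}, where the group id is
    the id of the first token in the group.
    '''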
word_groups = {}
id2class = {i: labels[i] for i in range(len(labels))}
for class_id, lwgroups_in_class in enumerate(class_words):
for ltokens_in_wgroup in lwgroups_in_class:
group_id = ltokens_in_wgroup[0]
ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
text_string = get_string(ltokens_to_ltexts)
word_groups[group_id] = {
'group_id': group_id,
'text': text_string,
'class': id2class[class_id],
'tokens': ltokens_in_wgroup
}
return word_groups
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
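    '''
    If `linking_id` is not itself a word-group id, return the id of the word
    group whose tokens contain it; otherwise return `linking_id` unchanged.
    '''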
if linking_id not in list(word_groups):
for wg_id, _word_group in word_groups.items():
if linking_id in _word_group['tokens']:
return wg_id
return linking_id
def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
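    '''Map token-level relation pairs onto word-group pairs, skipping ids that cannot be resolved.'''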
outputs = []
for pair in lrelations:
wg_from = verify_linking_id(word_groups, pair[0])
wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except KeyError:
            logger.info('Invalid pair: {} - {}'.format(wg_from, wg_to))
return outputs
def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
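    '''
    Build the generic KVU output for one document:
    - 'single':  key -> single value pairs
    - 'triplet': keys with multiple values but no header-linked key
    - 'table':   rows of header-linked key/value cells
    The result is also written as JSON to <dirname>/kvu_results/<basename>.
    '''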
word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}
# table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
triplet_pairs = []
single_pairs = []
table = []
# logger.info('key2values_relations', key2values_relations)
for key_group_id, list_value_group_ids in key2values_relations.items():
if len(list_value_group_ids) == 0: continue
elif len(list_value_group_ids) == 1:
value_group_id = list_value_group_ids[0]
single_pairs.append({word_groups[key_group_id]['text']: {
'text': word_groups[value_group_id]['text'],
'id': value_group_id,
'class': "value"
}})
else:
item = []
for value_group_id in list_value_group_ids:
if value_group_id not in header_value.keys():
header_name_for_value = "non-header"
else:
header_group_id = header_value[value_group_id][0]
header_name_for_value = word_groups[header_group_id]['text']
item.append({
'text': word_groups[value_group_id]['text'],
'header': header_name_for_value,
'id': value_group_id,
'class': 'value'
})
if key_group_id not in list(header_key.keys()):
triplet_pairs.append({
word_groups[key_group_id]['text']: item
})
else:
header_group_id = header_key[key_group_id][0]
header_name_for_key = word_groups[header_group_id]['text']
item.append({
'text': word_groups[key_group_id]['text'],
'header': header_name_for_key,
'id': key_group_id,
'class': 'key'
})
table.append({key_group_id: item})
if len(table) > 0:
table = sorted(table, key=lambda x: list(x.keys())[0])
table = [v for item in table for k, v in item.items()]
outputs = {}
outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
outputs['triplet'] = triplet_pairs
outputs['table'] = table
file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
write_to_json(file_path, outputs)
return outputs
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
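    '''
    Post-process the generic KVU output into VAT invoice fields: table cells are
    standardized against the VAT header dictionary, the invoice date is assembled
    as dd/mm/yyyy, the seller tax code keeps its first occurrence, and the final
    dict (including the item table) is written to `file_path` as JSON.
    '''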
vat_outputs = {}
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(vat_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
if header_name in list(item.keys()):
# item[header_name] = value['text']
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
table.append(item)
# VAT Information
single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
# logger.info('='*10, file_path)
# logger.info(vat_info)
# Combine VAT information and table
vat_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if key_name in ("Ngày, tháng, năm lập hóa đơn"):
if len(list_potential_value) == 1:
vat_outputs[key_name] = list_potential_value[0]['content']
else:
date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
for value in list_potential_value:
date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
else:
if len(list_potential_value) == 0: continue
if key_name in ("Mã số thuế người bán"):
selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
vat_outputs[key_name] = selected_value['content'].replace(' ', '')
else:
                selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lcs score
vat_outputs[key_name] = selected_value['content']
vat_outputs['table'] = table
write_to_json(file_path, vat_outputs)
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
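    '''
    Post-process the generic KVU output for SDSAP documents: table rows and
    header-linked triplet items are standardized with the AP dictionary and
    merged into one table, single key/value pairs keep the best LCS match, and
    the result is written to `file_path` as JSON.
    '''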
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
if header_name in list(item.keys()):
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
table.append(item)
triplet_pairs = []
for single_item in outputs['triplet']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
is_item_valid = 0
for key_name, list_value in single_item.items():
for value in list_value:
if value['header'] == "non-header":
continue
header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
if header_name in list(item.keys()):
is_item_valid = 1
item[header_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
if is_item_valid == 1:
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
                item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
item['productname'] = key_name
# triplet_pairs.append({key_name: new_item})
triplet_pairs.append(item)
single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
ap_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if len(list_potential_value) == 0: continue
        selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lcs score
ap_outputs[key_name] = selected_value['content']
table = table + triplet_pairs
ap_outputs['table'] = table
# ap_outputs['triplet'] = triplet_pairs
write_to_json(file_path, ap_outputs)