# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/utils/utils.py

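"""Utility helpers for Key-Value Understanding (KVU) post-processing.

Covers PDF-to-image conversion, OCR result I/O, linking predicted word
groups into key/value/header relations, and the export routines for the
FI-VAT and SBT (SDSAP) pipelines defined below.
"""
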
import os
import cv2
import json
import torch
import glob
import re
import numpy as np
from tqdm import tqdm
from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
from utils.kvu_dictionary import vat_dictionary, ap_dictionary
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

def create_dir(save_dir=''):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("Directory already exists.")
    logger.info('Save dir: {}'.format(save_dir))

def pdf2image(pdf_dir, save_dir):
    pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
    logger.info('No. pdf files: %d', len(pdf_files))
    for file in tqdm(pdf_files):
        pages = convert_from_path(file, 500)
        for i, page in enumerate(pages):
            page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
    logger.info('Done!!!')

def xyxy2xywh(bbox):
    return [
        float(bbox[0]),
        float(bbox[1]),
        float(bbox[2]) - float(bbox[0]),
        float(bbox[3]) - float(bbox[1]),
    ]

def write_to_json(file_path, content):
    with open(file_path, mode='w', encoding='utf8') as f:
        json.dump(content, f, ensure_ascii=False)

def read_json(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)

def read_xml(file_path):
    with open(file_path, 'r') as xml_file:
        return xml_file.read()

def write_to_xml(file_path, content):
    with open(file_path, mode="w", encoding='utf8') as f:
        f.write(content)

def write_to_xml_from_dict(file_path, content):
    xml = dicttoxml(content)
    xml_decode = xml.decode()
    with open(file_path, mode="w") as f:
        f.write(xml_decode)

def load_ocr_result(ocr_path):
    with open(ocr_path, 'r') as f:
        lines = f.read().splitlines()
    preds = []
    for line in lines:
        preds.append(line.split('\t'))
    return preds

def post_process_basic_ocr(lwords: list) -> list:
    pp_lwords = []
    for word in lwords:
        # NOTE: the target of this replace() appears to be a special
        # (non-printable) separator character that was lost in transcription.
        pp_lwords.append(word.replace("", " "))
    return pp_lwords

def read_ocr_result_from_txt(file_path: str):
    '''
    Read OCR output where each line is tab-separated:
    "x1<TAB>y1<TAB>x2<TAB>y2<TAB>text[<TAB>extra]" (the optional sixth field is ignored).
    Returns a list of bounding boxes and a list of words.
    '''
    with open(file_path, 'r') as f:
        lines = f.read().splitlines()
    boxes, words = [], []
    for line in lines:
        if line == "":
            continue
        word_info = line.split("\t")
        if len(word_info) == 6:
            x1, y1, x2, y2, text, _ = word_info
        elif len(word_info) == 5:
            x1, y1, x2, y2, text = word_info
        else:
            continue  # skip malformed lines
        x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
        if text and text != " ":
            words.append(text)
            boxes.append((x1, y1, x2, y2))
    return boxes, words

def get_colormap():
    # BGR colors, as used by OpenCV drawing calls
    return {
        'others': (0, 0, 255),     # others: red
        'title': (0, 255, 255),    # title: yellow
        'key': (255, 0, 0),        # key: blue
        'value': (0, 255, 0),      # value: green
        'header': (233, 197, 15),  # header
        'group': (0, 128, 128),    # group
        'relation': (0, 0, 255),   # relation: red
    }

def convert_image(image):
    exif = image._getexif()
    orientation = None
    if exif is not None:
        orientation = exif.get(0x0112)  # EXIF tag 0x0112 = Orientation
        # Convert the PIL image to OpenCV format
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        # Rotate the image in OpenCV if necessary
        if orientation == 3:
            image = cv2.rotate(image, cv2.ROTATE_180)
        elif orientation == 6:
            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        elif orientation == 8:
            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        image = np.asarray(image)
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    assert len(image.shape) == 3
    return image, orientation

def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
    image, orientation = convert_image(image)
    if orientation == 6:
        width, height, _ = image.shape
    else:
        height, width, _ = image.shape
    if len(pr_class_words) > 0:
        id2label = {k: labels[k] for k in range(len(labels))}
        for lb, groups in enumerate(pr_class_words):
            if lb == 0:
                continue
            for group_id, group in enumerate(groups):
                for i, word_id in enumerate(group):
                    x0, y0, x1, y1 = int(bbox[word_id][0] * width / 1000), int(bbox[word_id][1] * height / 1000), int(bbox[word_id][2] * width / 1000), int(bbox[word_id][3] * height / 1000)
                    cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
                    if i == 0:
                        x_center0, y_center0 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                    else:
                        # Chain each word to the previous one in the same group
                        x_center1, y_center1 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                        cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
                        x_center0, y_center0 = x_center1, y_center1
    if len(pr_relations) > 0:
        for pair in pr_relations:
            xyxy0 = int(bbox[pair[0]][0] * width / 1000), int(bbox[pair[0]][1] * height / 1000), int(bbox[pair[0]][2] * width / 1000), int(bbox[pair[0]][3] * height / 1000)
            xyxy1 = int(bbox[pair[1]][0] * width / 1000), int(bbox[pair[1]][1] * height / 1000), int(bbox[pair[1]][2] * width / 1000), int(bbox[pair[1]][3] * height / 1000)
            x_center0, y_center0 = int((xyxy0[0] + xyxy0[2]) / 2), int((xyxy0[1] + xyxy0[3]) / 2)
            x_center1, y_center1 = int((xyxy1[0] + xyxy1[2]) / 2), int((xyxy1[1] + xyxy1[3]) / 2)
            cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
    return image

def get_pairs(pairs: list, rel_from: str, rel_to: str) -> dict:
    outputs = {}
    for pair in pairs:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] in (rel_from, rel_to):
                is_rel[element['class']]['status'] = 1
                is_rel[element['class']]['value'] = element
        if all(v['status'] == 1 for v in is_rel.values()):
            outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
    return outputs

def get_table_relations(pairs: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
    list_keys = list(header_key_pairs.keys())
    relations = {k: [] for k in list_keys}
    for pair in pairs:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] == rel_from and element['group_id'] in list_keys:
                is_rel[rel_from]['status'] = 1
                is_rel[rel_from]['value'] = element
            if element['class'] == rel_to:
                is_rel[rel_to]['status'] = 1
                is_rel[rel_to]['value'] = element
        if all(v['status'] == 1 for v in is_rel.values()):
            relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
    return relations

def get_key2values_relations(key_value_pairs: dict):
    triple_linkings = {}
    for value_group_id, key_value_pair in key_value_pairs.items():
        key_group_id = key_value_pair[0]
        if key_group_id not in triple_linkings:
            triple_linkings[key_group_id] = []
        triple_linkings[key_group_id].append(value_group_id)
    return triple_linkings

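# Illustrative shape of the `class_words` input consumed below (toy token ids;
# real values come from the upstream KVU model): one list per class label, each
# holding word groups as lists of token indices, e.g. for
# labels = ['others', 'title', 'key', 'value', 'header']:
#   class_words = [
#       [],          # others
#       [[0, 1]],    # title: tokens 0 and 1 form one word group
#       [[4]],       # key
#       [[5, 6]],    # value
#       [],          # header
#   ]
# merged_token_to_wordgroup() keys each group by its first token id.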
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
    word_groups = {}
    id2class = {i: labels[i] for i in range(len(labels))}
    for class_id, lwgroups_in_class in enumerate(class_words):
        for ltokens_in_wgroup in lwgroups_in_class:
            group_id = ltokens_in_wgroup[0]
            ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
            text_string = get_string(ltokens_to_ltexts)
            word_groups[group_id] = {
                'group_id': group_id,
                'text': text_string,
                'class': id2class[class_id],
                'tokens': ltokens_in_wgroup
            }
    return word_groups

def verify_linking_id(word_groups: dict, linking_id: int) -> int:
    if linking_id not in word_groups:
        for wg_id, _word_group in word_groups.items():
            if linking_id in _word_group['tokens']:
                return wg_id
    return linking_id

def matched_wordgroup_relations(word_groups: dict, lrelations: list) -> list:
    outputs = []
    for pair in lrelations:
        wg_from = verify_linking_id(word_groups, pair[0])
        wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except KeyError:
            logger.info('Not a valid pair: %s, %s', wg_from, wg_to)
    return outputs

def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
    header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key')      # {key_group_id: [header_group_id, key_group_id]}
    header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value')  # {value_group_id: [header_group_id, value_group_id]}
    key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value')        # {value_group_id: [key_group_id, value_group_id]}
    # table_relations = get_table_relations(linking_pairs, header_key)          # {key_group_id: [value_group_id1, value_group_id2, ...]}
    key2values_relations = get_key2values_relations(key_value)                  # {key_group_id: [value_group_id1, value_group_id2, ...]}
    triplet_pairs = []
    single_pairs = []
    table = []
    for key_group_id, list_value_group_ids in key2values_relations.items():
        if len(list_value_group_ids) == 0:
            continue
        elif len(list_value_group_ids) == 1:
            value_group_id = list_value_group_ids[0]
            single_pairs.append({word_groups[key_group_id]['text']: {
                'text': word_groups[value_group_id]['text'],
                'id': value_group_id,
                'class': "value"
            }})
        else:
            item = []
            for value_group_id in list_value_group_ids:
                if value_group_id not in header_value.keys():
                    header_name_for_value = "non-header"
                else:
                    header_group_id = header_value[value_group_id][0]
                    header_name_for_value = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[value_group_id]['text'],
                    'header': header_name_for_value,
                    'id': value_group_id,
                    'class': 'value'
                })
            if key_group_id not in list(header_key.keys()):
                triplet_pairs.append({
                    word_groups[key_group_id]['text']: item
                })
            else:
                header_group_id = header_key[key_group_id][0]
                header_name_for_key = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[key_group_id]['text'],
                    'header': header_name_for_key,
                    'id': key_group_id,
                    'class': 'key'
                })
                table.append({key_group_id: item})
    if len(table) > 0:
        table = sorted(table, key=lambda x: list(x.keys())[0])
        table = [v for item in table for k, v in item.items()]
    outputs = {}
    outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
    outputs['triplet'] = triplet_pairs
    outputs['table'] = table
    file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
    write_to_json(file_path, outputs)
    return outputs

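# Shape of the dict returned by export_kvu_outputs (illustrative values only):
#   {
#       'single':  [{'<key text>': {'text': '<value text>', 'id': 3, 'class': 'value'}}],
#       'triplet': [{'<key text>': [{'text': '...', 'header': 'non-header', 'id': 5, 'class': 'value'}, ...]}],
#       'table':   [[{'text': '...', 'header': '<header text>', 'id': 7, 'class': 'value'}, ...], ...],
#   }
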
# For FI-VAT project
def get_vat_table_information(outputs):
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in list(vat_dictionary(header=True).keys())}
        for cell in single_item:
            header_name, score, processed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
            if header_name in list(item.keys()):
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })
        for header_name, value in item.items():
            if len(value) == 0:
                if header_name in ("Số lượng", "Doanh số mua chưa có thuế"):
                    item[header_name] = '0'
                else:
                    item[header_name] = None
                continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content']  # keep the candidate with the highest LCS score
        item = post_process_for_item(item)
        if item["Mặt hàng"] is None:
            continue
        table.append(item)
    return table

def get_vat_information(outputs):
    # VAT Information
    single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            key_name, score, processed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
            if key_name in list(single_pairs.keys()):
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': value['id'],
                })
    for triplet in outputs['triplet']:
        for key, value_list in triplet.items():
            if len(value_list) == 1:
                key_name, score, processed_text = vat_standardizer(key, threshold=0.8, header=False)
                if key_name in list(single_pairs.keys()):
                    single_pairs[key_name].append({
                        'content': value_list[0]['text'],
                        'processed_key_name': processed_text,
                        'lcs_score': score,
                        'token_id': value_list[0]['id']
                    })
            for pair in value_list:
                key_name, score, processed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
                if key_name in list(single_pairs.keys()):
                    single_pairs[key_name].append({
                        'content': pair['text'],
                        'processed_key_name': processed_text,
                        'lcs_score': score,
                        'token_id': pair['id']
                    })
    for table_row in outputs['table']:
        for pair in table_row:
            key_name, score, processed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
            if key_name in list(single_pairs.keys()):
                single_pairs[key_name].append({
                    'content': pair['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': pair['id']
                })
    return single_pairs

def post_process_vat_information(single_pairs):
    vat_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if key_name == "Ngày, tháng, năm lập hóa đơn":
            if len(list_potential_value) == 1:
                vat_outputs[key_name] = list_potential_value[0]['content']
            else:
                date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
                for value in list_potential_value:
                    date_time[value['processed_key_name']] = re.sub(r"[^0-9]", "", value['content'])
                vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
        else:
            if len(list_potential_value) == 0:
                continue
            if key_name == "Mã số thuế người bán":
                selected_value = min(list_potential_value, key=lambda x: x['token_id'])  # take the first tax code on the page
                tax_code_raw = selected_value['content']
                if len(tax_code_raw.replace(' ', '')) not in (10, 13):  # drop the duplicated leading digits
                    tax_code_raw = max(tax_code_raw.split(' '), key=len)
                vat_outputs[key_name] = tax_code_raw.replace(' ', '')
            else:
                selected_value = max(list_potential_value, key=lambda x: x['lcs_score'])  # keep the candidate with the highest LCS score
                vat_outputs[key_name] = selected_value['content']
    return vat_outputs

def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
    # List of items in table
    table = get_vat_table_information(outputs)
    # VAT Information
    single_pairs = get_vat_information(outputs)
    vat_outputs = post_process_vat_information(single_pairs)
    # Combine VAT information and table
    vat_outputs['table'] = table
    write_to_json(file_path, vat_outputs)
    return vat_outputs

# For SBT project
def get_ap_table_information(outputs):
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in list(ap_dictionary(header=True).keys())}
        for cell in single_item:
            header_name, score, processed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
            if header_name in list(item.keys()):
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })
        for header_name, value in item.items():
            if len(value) == 0:
                item[header_name] = None
                continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content']  # keep the candidate with the highest LCS score
        table.append(item)
    return table

def get_ap_triplet_information(outputs):
    triplet_pairs = []
    for single_item in outputs['triplet']:
        item = {k: [] for k in list(ap_dictionary(header=True).keys())}
        is_item_valid = False
        for key_name, list_value in single_item.items():
            for value in list_value:
                if value['header'] == "non-header":
                    continue
                header_name, score, processed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
                if header_name in list(item.keys()):
                    is_item_valid = True
                    item[header_name].append({
                        'content': value['text'],
                        'processed_key_name': processed_text,
                        'lcs_score': score,
                        'token_id': value['id']
                    })
        if is_item_valid:
            for header_name, value in item.items():
                if len(value) == 0:
                    item[header_name] = None
                    continue
                item[header_name] = max(value, key=lambda x: x['lcs_score'])['content']  # keep the candidate with the highest LCS score
            item['productname'] = key_name  # key text of the (last) triplet entry becomes the product name
            triplet_pairs.append(item)
    return triplet_pairs

def get_ap_information(outputs):
    single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            key_name, score, processed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
            if key_name in list(single_pairs):
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': value['id']
                })
    ap_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if len(list_potential_value) == 0:
            continue
        selected_value = max(list_potential_value, key=lambda x: x['lcs_score'])  # keep the candidate with the highest LCS score
        ap_outputs[key_name] = selected_value['content']
    return ap_outputs

def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
    # List of items in table
    table = get_ap_table_information(outputs)
    triplet_pairs = get_ap_triplet_information(outputs)
    table = table + triplet_pairs
    ap_outputs = get_ap_information(outputs)
    ap_outputs['table'] = table
    write_to_json(file_path, ap_outputs)
    return ap_outputs
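
if __name__ == "__main__":
    # Minimal smoke-test sketch with toy inputs (hypothetical values; real
    # lwords/class_words/lrelations come from the upstream KVU model):
    # one 'key' group (token 0) linked to one 'value' group (token 1).
    lwords = ["Total:", "100.000"]
    class_words = [[], [], [[0]], [[1]], []]  # others, title, key, value, header
    lrelations = [[0, 1]]
    create_dir(os.path.join('.', 'kvu_results'))  # export_kvu_outputs writes into this folder
    results = export_kvu_outputs('./sample.json', lwords, class_words, lrelations)
    logger.info('KVU outputs: %s', results)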