#!/usr/bin/env python3
import os
|
|
import json
|
|
import time
|
|
|
|
from common import json2xml
|
|
from common.json2xml import convert_key_names, replace_xml_values
|
|
from common.utils_kvu.split_docs import split_docs, merge_sbt_output
|
|
|
|
# from api.OCRBase.prediction import predict as ocr_predict
|
|
# from api.Kie_Invoice_AP.prediction_sap import predict
|
|
# from api.Kie_Invoice_AP.prediction_fi import predict_fi
|
|
# from api.manulife.predict_manulife import predict as predict_manulife
|
|
from api.sdsap_sbt.prediction_sbt import predict as predict_sbt
|
|
|
|
os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/'
|
|
|
|
def check_label_exists(array, target_label):
    """Return True if any element of *array* carries the given label.

    Args:
        array (list[dict]): objects, each expected to have a "label" key.
        target_label: label value to look for.

    Returns:
        bool: True when some ``obj["label"] == target_label``, else False.
    """
    # any() short-circuits on the first match, same as the original
    # early-return loop.
    return any(obj["label"] == target_label for obj in array)


def compile_output(list_url):
    """Run ``predict`` on every page and merge the per-page fields.

    NOTE(review): the import for ``predict`` is commented out at the top of
    this file, so calling this function as-is raises ``NameError`` — confirm
    which prediction backend (SAP / FI / ...) should be wired in.

    Args:
        list_url (list): pages, each a dict with "page_number", "file_url"
            and "request_file_id" keys.

    Returns:
        dict: static "model" metadata plus
            - "combine_results": one entry per label, keeping the first
              non-empty value seen across all pages, and
            - "pages": the raw per-page field lists.
    """

    results = {
        "model":{
            "name":"Invoice",
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }
    compile_outputs = []
    # De-duplicated fields across pages; first non-empty value per label wins.
    compiled = []
    for page in list_url:
        output_model = predict(page['page_number'], page['file_url'])
        for field in output_model['fields']:
            # Keep the first non-empty value seen for each new label.
            if field['value'] != "" and not check_label_exists(compiled, field['label']):
                element = {
                    'label': field['label'],
                    'value': field['value'],
                }
                compiled.append(element)
            # Subsequent 'table' fields are appended onto the already-stored
            # table entry. NOTE(review): this assumes the stored table value
            # is a list — confirm against the predict() output shape.
            elif field['label'] == 'table' and check_label_exists(compiled, "table"):
                for index, obj in enumerate(compiled):
                    if obj['label'] == 'table':
                        compiled[index]['value'].append(field['value'])
        compile_output = {
            'page_index': page['page_number'],
            'request_file_id': page['request_file_id'],
            'fields': output_model['fields']
        }
        compile_outputs.append(compile_output)
    results['combine_results'] = compiled
    results['pages'] = compile_outputs
    return results


def update_null_values(kvu_result, next_kvu_result):
    """Backfill ``None`` entries of *kvu_result* in place.

    For every key whose value in *kvu_result* is ``None``, copy the value
    stored under the same key in *next_kvu_result* — but only when that
    value exists and is itself not ``None``. Returns nothing.
    """
    for key in kvu_result:
        if kvu_result[key] is not None:
            continue
        replacement = next_kvu_result.get(key)
        if replacement is not None:
            kvu_result[key] = replacement


def replace_empty_null_values(my_dict):
    """Replace every empty-string value in *my_dict* with ``None``.

    The dictionary is mutated in place and also returned, so callers may
    use either the argument or the return value.
    """
    empty_keys = [key for key, value in my_dict.items() if value == '']
    for key in empty_keys:
        my_dict[key] = None
    return my_dict


def compile_output_fi(list_url):
    """Predict FI key-values per page and render the result as XML.

    NOTE(review): the import for ``predict_fi`` is commented out at the top
    of this file, so calling this function as-is raises ``NameError``.

    Args:
        list_url (list): pages, each a dict with "page_number" and
            "file_url" keys. Must be non-empty, otherwise ``kvu_result`` /
            ``output_kie`` below are never bound before use.

    Returns:
        dict: static "model" metadata plus "combine_results" holding the
        rendered XML string under the "xml" key.
    """

    results = {
        "model":{
            "name":"Invoice",
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }
    # Loop through the list_url to update kvu_result.
    # NOTE(review): every interior page is predicted twice (once as `page`,
    # once as `next_page`), and the merged results computed in this loop are
    # discarded — only the final re-prediction of the last page (below)
    # reaches the output. Confirm whether the loop was meant to accumulate.
    for i in range(len(list_url) - 1):
        page = list_url[i]
        next_page = list_url[i + 1]
        kvu_result, output_kie = predict_fi(page['page_number'], page['file_url'])
        next_kvu_result, next_output_kie = predict_fi(next_page['page_number'], next_page['file_url'])

        # Backfill None/empty entries from the following page's prediction.
        update_null_values(kvu_result, next_kvu_result)
        output_kie = replace_empty_null_values(output_kie)
        next_output_kie = replace_empty_null_values(next_output_kie)
        update_null_values(output_kie, next_output_kie)

    # Handle the last item in the list_url
    if list_url:
        page = list_url[-1]
        kvu_result, output_kie = predict_fi(page['page_number'], page['file_url'])

    # Normalize key names, then splice the values into the XML template.
    converted_dict = convert_key_names(kvu_result)
    converted_dict.update(convert_key_names(output_kie))
    output_fi = replace_xml_values(json2xml.xml_template3, converted_dict)
    field_fi = {
        "xml": output_fi,
    }
    results['combine_results'] = field_fi
    # results['combine_results'] = converted_dict
    # results['combine_results_kie'] = output_kie
    return results


def compile_output_ocr_base(list_url):
    """Compile output of OCRBase.

    NOTE(review): the import for ``ocr_predict`` is commented out at the
    top of this file, so a non-empty *list_url* raises ``NameError``.

    Args:
        list_url (list): pages, each a dict with "page_number",
            "file_url" and "request_file_id" keys.

    Returns:
        dict: static "model" metadata plus "pages", one entry per input
        page carrying its index, request-file id and predicted fields.
    """
    pages = []
    for page in list_url:
        prediction = ocr_predict(page['page_number'], page['file_url'])
        pages.append({
            'page_index': page['page_number'],
            'request_file_id': page['request_file_id'],
            'fields': prediction['fields']
        })
    return {
        "model": {
            "name": "OCRBase",
            "confidence": 1.0,
            "type": "ocrbase",
            "isValid": True,
            "shape": "letter",
        },
        "pages": pages,
    }


def compile_output_manulife(list_url):
    """Run ``predict_manulife`` on every page and split pages into documents.

    NOTE(review): the import for ``predict_manulife`` is commented out at
    the top of this file, so a non-empty *list_url* raises ``NameError``
    until the import is restored.

    Args:
        list_url (list): pages, each a dict with "page_number" and
            "file_url" keys.

    Returns:
        dict: {"total_pages", "ocr_num_pages", "document"} where
        "document" is the result of ``split_docs`` over the per-page
        outputs.
    """
    # Fix: the original also pre-built an Invoice "model" dict here that was
    # unconditionally overwritten by the return value below — dead code,
    # removed.
    outputs = []
    for page in list_url:
        output_model = predict_manulife(page['page_number'], page['file_url'])  # gotta be predict_manulife(), for the time being, this function is not avaible, we just leave a dummy function here instead
        print("output_model", output_model)  # debug trace kept from original
        outputs.append(output_model)
    print("outputs", outputs)
    documents = split_docs(outputs)
    print("documents", documents)
    return {
        "total_pages": len(list_url),
        "ocr_num_pages": len(list_url),
        "document": documents
    }


def compile_output_sbt(list_url):
    """Run ``predict_sbt`` on every page, merge pages into documents and
    record coarse timing information.

    Args:
        list_url (list): pages, each a dict with "page_number" and
            "file_url" keys; an optional "doc_type" key is copied onto
            that page's prediction before merging.

    Returns:
        dict: {"total_pages", "ocr_num_pages", "document",
        "inference_profile"} — "inference_profile" holds
        ``[start, per_page_timestamps]`` for inference and
        ``[start, end]`` for post-processing (epoch seconds).
    """
    # Fix: the original also pre-built an Invoice "model" dict here that was
    # unconditionally overwritten by the return value below — dead code,
    # removed.
    inference_profile = {}

    outputs = []
    start = time.time()
    pages_predict_time = []  # wall-clock timestamp after each page's predict
    for page in list_url:
        output_model = predict_sbt(page['page_number'], page['file_url'])
        pages_predict_time.append(time.time())
        if "doc_type" in page:
            output_model['doc_type'] = page['doc_type']
        outputs.append(output_model)

    start_postprocess = time.time()
    documents = merge_sbt_output(outputs)
    inference_profile["postprocess"] = [start_postprocess, time.time()]
    inference_profile["inference"] = [start, pages_predict_time]

    return {
        "total_pages": len(list_url),
        "ocr_num_pages": len(list_url),
        "document": documents,
        "inference_profile": inference_profile
    }


def main():
    """Demo entry point: compile one sample document and dump it to JSON.

    NOTE(review): ``compile_output`` relies on ``predict``, whose import is
    commented out at the top of the file — restore it before running.
    """
    # Fix: the original list literal ended with a bare `...` (Ellipsis)
    # placeholder element; compile_output would subscript it
    # (`...['page_number']`) and raise TypeError — removed. Extend the list
    # with more real page dicts as needed.
    list_url = [
        {
            "file_url": "https://www.irs.gov/pub/irs-pdf/fw9.pdf",
            "page_number": 1,
            "request_file_id": 1,
        },
    ]
    results = compile_output(list_url)
    with open('output.json', 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()