sbt-idp/cope2n-ai-fi/common/process_pdf.py

import os
import json
import time

from common import json2xml
from common.json2xml import convert_key_names, replace_xml_values
from common.utils_kvu.split_docs import split_docs, merge_sbt_output

# These predictors are referenced by the compile_output* functions below,
# so their imports must stay active for this module to run.
from api.OCRBase.prediction import predict as ocr_predict
from api.Kie_Invoice_AP.prediction_sap import predict
from api.Kie_Invoice_AP.prediction_fi import predict_fi
from api.manulife.predict_manulife import predict as predict_manulife
from api.sdsap_sbt.prediction_sbt import predict as predict_sbt

# NOTE: setting PYTHONPATH at runtime only affects child processes; the
# current interpreter's import path is unchanged (use sys.path for that).
os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/'

def check_label_exists(array, target_label):
    for obj in array:
        if obj["label"] == target_label:
            return True  # Label exists in the array
    return False  # Label does not exist in the array

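# A minimal illustration of the helper above; the field dicts are hypothetical
# but mirror the {'label': ..., 'value': ...} elements built in compile_output:
#
#     fields = [{"label": "table", "value": []}]
#     check_label_exists(fields, "table")    # -> True
#     check_label_exists(fields, "invoice")  # -> False
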
def compile_output(list_url):
    """Compile KIE predictions for every page of a document.

    Args:
        list_url (list): list of page dicts, e.g. [{"page_number": 1,
            "file_url": url, "request_file_id": id}, ...]

    Returns:
        dict: compiled output with per-page fields and combined results
    """
    results = {
        "model": {
            "name": "Invoice",
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }
    compile_outputs = []
    compiled = []
    for page in list_url:
        output_model = predict(page['page_number'], page['file_url'])
        for field in output_model['fields']:
            # Keep the first non-empty value seen for each label...
            if field['value'] != "" and not check_label_exists(compiled, field['label']):
                element = {
                    'label': field['label'],
                    'value': field['value'],
                }
                compiled.append(element)
            # ...except tables, whose values are accumulated across pages.
            elif field['label'] == 'table' and check_label_exists(compiled, "table"):
                for index, obj in enumerate(compiled):
                    if obj['label'] == 'table':
                        compiled[index]['value'].append(field['value'])
        compile_output = {
            'page_index': page['page_number'],
            'request_file_id': page['request_file_id'],
            'fields': output_model['fields']
        }
        compile_outputs.append(compile_output)
    results['combine_results'] = compiled
    results['pages'] = compile_outputs
    return results

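# Shape of the compiled output (a sketch; labels and values are hypothetical,
# the structure follows the code above):
#
#     {
#         "model": {...},
#         "combine_results": [{"label": "invoice_no", "value": "INV-1"},
#                             {"label": "table", "value": [...]}],
#         "pages": [{"page_index": 1, "request_file_id": 1, "fields": [...]}]
#     }
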
def update_null_values(kvu_result, next_kvu_result):
    """Fill None values in kvu_result in place with values from next_kvu_result."""
    for key, value in kvu_result.items():
        if value is None and next_kvu_result.get(key) is not None:
            kvu_result[key] = next_kvu_result[key]


def replace_empty_null_values(my_dict):
    """Replace empty-string values with None so they can be back-filled."""
    for key, value in my_dict.items():
        if value == '':
            my_dict[key] = None
    return my_dict

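# A quick sketch of how the two helpers above compose (hypothetical keys):
#
#     page = replace_empty_null_values({"invoice_no": "", "total": "100"})
#     # -> {"invoice_no": None, "total": "100"}
#     update_null_values(page, {"invoice_no": "INV-1", "total": "999"})
#     # -> page is now {"invoice_no": "INV-1", "total": "100"}; only None
#     #    values are back-filled, existing values are kept.
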
def compile_output_fi(list_url):
    """Compile FI predictions for every page and merge them into one XML output.

    Args:
        list_url (list): list of page dicts, e.g. [{"page_number": 1,
            "file_url": url, "request_file_id": id}, ...]

    Returns:
        dict: compiled output containing the filled XML template
    """
    results = {
        "model": {
            "name": "Invoice",
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }
    # Walk the pages pairwise, back-filling null values from the following page.
    for i in range(len(list_url) - 1):
        page = list_url[i]
        next_page = list_url[i + 1]
        kvu_result, output_kie = predict_fi(page['page_number'], page['file_url'])
        next_kvu_result, next_output_kie = predict_fi(next_page['page_number'], next_page['file_url'])
        update_null_values(kvu_result, next_kvu_result)
        output_kie = replace_empty_null_values(output_kie)
        next_output_kie = replace_empty_null_values(next_output_kie)
        update_null_values(output_kie, next_output_kie)
    # Handle the last page (or the only page of a single-page document).
    if list_url:
        page = list_url[-1]
        kvu_result, output_kie = predict_fi(page['page_number'], page['file_url'])
    converted_dict = convert_key_names(kvu_result)
    converted_dict.update(convert_key_names(output_kie))
    output_fi = replace_xml_values(json2xml.xml_template3, converted_dict)
    field_fi = {
        "xml": output_fi,
    }
    results['combine_results'] = field_fi
    # results['combine_results'] = converted_dict
    # results['combine_results_kie'] = output_kie
    return results

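# Sketch of the final templating step above (xml_template3 and the expected
# key names live in common.json2xml; the values shown here are hypothetical):
#
#     converted = {"InvoiceNumber": "INV-1", "GrandTotal": "100"}
#     xml_out = replace_xml_values(json2xml.xml_template3, converted)
#     # -> the XML template string with matching entries filled in
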
def compile_output_ocr_base(list_url):
    """Compile raw OCRBase output for every page.

    Args:
        list_url (list): list of page dicts, e.g. [{"page_number": 1,
            "file_url": url, "request_file_id": id}, ...]

    Returns:
        dict: compiled per-page OCR output
    """
    results = {
        "model": {
            "name": "OCRBase",
            "confidence": 1.0,
            "type": "ocrbase",
            "isValid": True,
            "shape": "letter",
        }
    }
    compile_outputs = []
    for page in list_url:
        output_model = ocr_predict(page['page_number'], page['file_url'])
        compile_output = {
            'page_index': page['page_number'],
            'request_file_id': page['request_file_id'],
            'fields': output_model['fields']
        }
        compile_outputs.append(compile_output)
    results['pages'] = compile_outputs
    return results

def compile_output_manulife(list_url):
    """Compile Manulife predictions for every page and split them into documents.

    Args:
        list_url (list): list of page dicts, e.g. [{"page_number": 1,
            "file_url": url, "request_file_id": id}, ...]

    Returns:
        dict: compiled output with page counts and split documents
    """
    outputs = []
    for page in list_url:
        # predict_manulife is not available yet; for the time being a dummy
        # function stands in for it here.
        output_model = predict_manulife(page['page_number'], page['file_url'])
        print("output_model", output_model)
        outputs.append(output_model)
    print("outputs", outputs)
    documents = split_docs(outputs)
    print("documents", documents)
    results = {
        "total_pages": len(list_url),
        "ocr_num_pages": len(list_url),
        "document": documents
    }
    return results

def compile_output_sbt(list_url, metadata):
    """Compile SBT predictions for every page, with per-stage timing.

    Args:
        list_url (list): list of page dicts, e.g. [{"page_number": 1,
            "file_url": url, "request_file_id": id}, ...]
        metadata (dict): request metadata forwarded to the predictor

    Returns:
        dict: compiled output with merged documents and an inference profile
    """
    inference_profile = {}
    outputs = []
    start = time.time()
    pages_predict_time = []
    for page in list_url:
        output_model = predict_sbt(page['page_number'], page['file_url'], metadata)
        pages_predict_time.append(time.time())
        if "doc_type" in page:
            output_model['doc_type'] = page['doc_type']
        outputs.append(output_model)
    start_postprocess = time.time()
    documents = merge_sbt_output(outputs)
    # Record [start, end] timestamps for postprocessing and
    # [start, per-page completion times] for inference.
    inference_profile["postprocess"] = [start_postprocess, time.time()]
    inference_profile["inference"] = [start, pages_predict_time]
    results = {
        "total_pages": len(list_url),
        "ocr_num_pages": len(list_url),
        "document": documents,
        "inference_profile": inference_profile
    }
    return results

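# How the profile above can be read back (a sketch; the key layout follows
# the assignments in compile_output_sbt):
#
#     profile = results["inference_profile"]
#     postprocess_secs = profile["postprocess"][1] - profile["postprocess"][0]
#     start, page_times = profile["inference"]
#     total_inference_secs = page_times[-1] - start if page_times else 0.0
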
def main():
    """Run a small end-to-end example and dump the result to disk."""
    list_url = [
        {"file_url": "https://www.irs.gov/pub/irs-pdf/fw9.pdf", "page_number": 1, "request_file_id": 1},
        # ... more pages
    ]
    results = compile_output(list_url)
    with open('output.json', 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()