276 lines
6.9 KiB
Python
Executable File
276 lines
6.9 KiB
Python
Executable File
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import requests
|
|
import tqdm
|
|
from sdsvkie.utils.eval_kie import eval_kie
|
|
from sdsvkie.utils.io_file import read_json, write_json
|
|
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
# format=""
|
|
|
|
)
|
|
logger = logging.getLogger()
|
|
|
|
|
|
HEADERS = {
|
|
'accept': 'application/json',
|
|
'Authorization': 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2OTA2ODk4ODcsInVzZXJuYW1lIjoiYWRtaW4ifQ.Oybpc9tBsN35vCn3jzekkABDQKJT6yO1aBBJ4rMNln0'
|
|
}
|
|
|
|
URL = "http://107.120.133.27:8082/predict/image"
|
|
|
|
|
|
def run(
|
|
data_dir: str,
|
|
url: str,
|
|
gt_path: str,
|
|
field_pred_file: str,
|
|
samples: Union[int, None] = None,
|
|
|
|
):
|
|
|
|
files = get_files(data_dir, recursive=False, limit=samples)
|
|
preds = predict(url, files)
|
|
|
|
## process for table
|
|
# table_eval_result = {}
|
|
# table_preds = get_table_preds(preds)
|
|
# table_eval_result =
|
|
|
|
# process for seller, buyer, ...
|
|
field_eval_result = {}
|
|
# field_preds = get_field_preds_from_api(api_preds=preds)
|
|
field_preds = get_field_preds_from_file(pred_file=field_pred_file)
|
|
classes = get_classes(preds=field_preds)
|
|
if len(classes) == 0:
|
|
raise Exception("Can not get the classes list")
|
|
field_eval_result = eval(
|
|
gt=gt_path,
|
|
pred=field_preds,
|
|
classes=classes,
|
|
classes_ignore=['other', 'table']
|
|
)
|
|
print(field_eval_result)
|
|
## combine result
|
|
combine_result = {}
|
|
# combine_result = combine_result(table_res=table_eval_result, field_res=field_eval_result)
|
|
|
|
|
|
print_result(
|
|
data_path=data_dir,
|
|
num_samples=len(list(field_preds.keys())),
|
|
target_level=0.05,
|
|
result=1.0, # edit here
|
|
)
|
|
return combine_result
|
|
|
|
|
|
|
|
def print_result(
|
|
data_path: str,
|
|
num_samples:int,
|
|
target_level: float,
|
|
result: float,
|
|
metric: str = "NLD",
|
|
avg_time: float = 1.6363
|
|
):
|
|
|
|
print(f"Path of validation dataset: {data_path}\n"
|
|
+ f"Number of validation dataset: {num_samples}\n"
|
|
+ f"Evaluation metric: {metric}\n"
|
|
+ f"Target level: {target_level}\n"
|
|
+ f"Archieved level: {result}\n"
|
|
+ f"Average time: {avg_time}\n"
|
|
+ f"Verification result: {'PASS' if result > target_level else 'FAILED'}"
|
|
)
|
|
|
|
|
|
def get_field_preds_from_api(api_preds: str) -> dict:
|
|
field_preds = get_fields_preds(api_preds)
|
|
field_preds = combine_to_single_file(field_preds)
|
|
return field_preds
|
|
|
|
def get_field_preds_from_file(pred_file: str) -> dict:
|
|
"""
|
|
Get predictions from json file
|
|
"""
|
|
field_preds = read_json(pred_file)
|
|
return field_preds
|
|
|
|
|
|
|
|
def get_fields_preds(preds: List[Dict]):
|
|
preds = [
|
|
{item['file_path']: format_output_api(item['response_dict'])}
|
|
for item in preds
|
|
]
|
|
return preds
|
|
|
|
def combine_result(table_res: Dict, field_res: Dict):
|
|
return {}
|
|
|
|
def _str2dict(text: str) -> Dict:
|
|
try:
|
|
data = json.loads(text)
|
|
except Exception as err:
|
|
logger.error(f"{err} - data: {text}")
|
|
data = {}
|
|
return data
|
|
|
|
|
|
def predict_one_file(url: str, file: Union[str, Path]) -> Dict:
|
|
"""
|
|
Output format:
|
|
{
|
|
file_path: path of file
|
|
response_dict:
|
|
}
|
|
"""
|
|
if isinstance(file, str):
|
|
file = Path(file)
|
|
|
|
payload = {}
|
|
filename = file.name
|
|
files = [
|
|
(
|
|
'file',
|
|
(
|
|
filename,
|
|
open(str(file), 'rb'),
|
|
'application/pdf'
|
|
)
|
|
)
|
|
]
|
|
# logger.info(f"Files: {file}")
|
|
response = requests.request(
|
|
"POST", url, headers=HEADERS, data=payload, files=files)
|
|
|
|
response_dict = _str2dict(response.text)
|
|
|
|
return {
|
|
"file_path": str(file),
|
|
"pred_data": response_dict
|
|
}
|
|
|
|
|
|
def predict(url: str, files: List[Union[str, Path]]) -> List[Dict]:
|
|
"""
|
|
List of {'file_path', 'response_dict'}
|
|
"""
|
|
|
|
preds = []
|
|
for idx, file in tqdm.tqdm(enumerate(files)):
|
|
try:
|
|
pred = predict_one_file(url, file)
|
|
preds.append(pred)
|
|
except:
|
|
logger.info(f"Error at file: {file}")
|
|
return preds
|
|
|
|
|
|
def get_files(data_dir: str, recursive: bool = False, limit: Union[int, None] = None) -> List[Union[Path, str]]:
|
|
if recursive:
|
|
files = Path(data_dir).rglob("*")
|
|
else:
|
|
files = Path(data_dir).glob("*")
|
|
|
|
if limit:
|
|
files = list(files)[:limit]
|
|
return files
|
|
|
|
|
|
def _stem_filename(filename: str) -> str:
|
|
"""
|
|
Stem a file path: x/y.txt -> y
|
|
"""
|
|
return Path(filename).stem
|
|
|
|
|
|
def format_output_api(output_api: Dict, skip_fields=['table']) -> Dict:
|
|
if "pages" not in output_api:
|
|
return {}
|
|
pages = output_api['pages']
|
|
|
|
result = {}
|
|
for page in pages:
|
|
fields = page['fields']
|
|
for field_item in fields:
|
|
field_label, field_value = field_item['label'], field_item['value']
|
|
if field_label in result or field_label in skip_fields:
|
|
continue
|
|
result[field_label] = field_value
|
|
return result
|
|
|
|
|
|
def combine_to_single_file(preds: List[Dict]) -> None:
|
|
if len(preds) == 0:
|
|
return {}
|
|
combined_data = {
|
|
_stem_filename(item["filename"]): item["pred_data"]
|
|
for item in preds
|
|
}
|
|
return combined_data
|
|
|
|
|
|
def eval(
|
|
gt: Union[str, Dict],
|
|
pred: Union[str, Dict],
|
|
classes: List[str],
|
|
classes_ignore: List[str] = []
|
|
) -> Dict:
|
|
eval_res = eval_kie(
|
|
gt_e2e_path=gt,
|
|
pred_e2e_path=pred,
|
|
kie_labels=classes,
|
|
skip_labels=classes_ignore
|
|
)
|
|
return eval_res
|
|
|
|
|
|
def get_classes(preds: Dict) -> List[str]:
|
|
classes = []
|
|
for k, v in preds.items():
|
|
if v:
|
|
classes = list(v.keys())
|
|
break
|
|
return classes
|
|
|
|
|
|
def test():
|
|
import requests
|
|
|
|
url = "http://107.120.133.27:8082/predict/image"
|
|
|
|
payload = {}
|
|
files = [
|
|
('file', ('(1 of 19)_HOADON_1C23TYY_50.pdf', open(
|
|
'/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in/(1 of 19)_HOADON_1C23TYY_50.pdf', 'rb'), 'application/pdf'))
|
|
]
|
|
headers = {
|
|
'accept': 'application/json',
|
|
'Authorization': 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2OTA2ODk4ODcsInVzZXJuYW1lIjoiYWRtaW4ifQ.Oybpc9tBsN35vCn3jzekkABDQKJT6yO1aBBJ4rMNln0'
|
|
}
|
|
|
|
response = requests.request(
|
|
"POST", url, headers=headers, data=payload, files=files)
|
|
|
|
print(response.text)
|
|
# print(json.loa ds(response.text))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
limit = 5
|
|
run(
|
|
data_dir="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in",
|
|
url=URL,
|
|
gt_path="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in.json",
|
|
field_pred_file="/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/workdirs/invoice/06062023/invoice_all_in_final_e2e_21072023_5.json",
|
|
samples=limit
|
|
)
|
|
|
|
# test()
|