sbt-idp/cope2n-ai-fi/modules/sdsvkie/eval_with_api.py

import json
import logging
from pathlib import Path
from typing import Dict, List, Union

import requests
import tqdm

from sdsvkie.utils.eval_kie import eval_kie
from sdsvkie.utils.io_file import read_json

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

HEADERS = {
    'accept': 'application/json',
    'Authorization': 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2OTA2ODk4ODcsInVzZXJuYW1lIjoiYWRtaW4ifQ.Oybpc9tBsN35vCn3jzekkABDQKJT6yO1aBBJ4rMNln0'
}
URL = "http://107.120.133.27:8082/predict/image"
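
# NOTE: HEADERS and URL above target a specific API deployment; the bearer
# token is environment-specific and will need to be replaced once it expires.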


def run(
    data_dir: str,
    url: str,
    gt_path: str,
    field_pred_file: str,
    samples: Union[int, None] = None,
):
    files = get_files(data_dir, recursive=False, limit=samples)
    preds = predict(url, files)

    ## process for table
    # table_eval_result = {}
    # table_preds = get_table_preds(preds)

    # process for seller, buyer, ...
    field_eval_result = {}
    # field_preds = get_field_preds_from_api(api_preds=preds)
    field_preds = get_field_preds_from_file(pred_file=field_pred_file)
    classes = get_classes(preds=field_preds)
    if len(classes) == 0:
        raise Exception("Cannot get the classes list")
    field_eval_result = evaluate(
        gt=gt_path,
        pred=field_preds,
        classes=classes,
        classes_ignore=['other', 'table']
    )
    print(field_eval_result)

    ## combine result
    combined_result = {}
    # combined_result = combine_result(table_res=table_eval_result, field_res=field_eval_result)
    print_result(
        data_path=data_dir,
        num_samples=len(list(field_preds.keys())),
        target_level=0.05,
        result=1.0,  # TODO: replace with the metric achieved in field_eval_result
    )
    return combined_result
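
# For reference, `field_preds` (as loaded by get_field_preds_from_file) is
# expected to map file stems to {label: value} dicts, e.g. (values illustrative):
#   {"invoice_001": {"seller": "ACME Corp", "total_amount": "1,200,000"}}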


def print_result(
    data_path: str,
    num_samples: int,
    target_level: float,
    result: float,
    metric: str = "NLD",
    avg_time: float = 1.6363
):
    print(f"Path of validation dataset: {data_path}\n"
          + f"Number of validation dataset: {num_samples}\n"
          + f"Evaluation metric: {metric}\n"
          + f"Target level: {target_level}\n"
          + f"Achieved level: {result}\n"
          + f"Average time: {avg_time}\n"
          + f"Verification result: {'PASS' if result > target_level else 'FAILED'}"
          )
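
# Example of the summary this prints (values illustrative):
#   Path of validation dataset: /data/invoices/test
#   Number of validation dataset: 5
#   Evaluation metric: NLD
#   Target level: 0.05
#   Achieved level: 1.0
#   Average time: 1.6363
#   Verification result: PASS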


def get_field_preds_from_api(api_preds: List[Dict]) -> Dict:
    field_preds = get_fields_preds(api_preds)
    field_preds = combine_to_single_file(field_preds)
    return field_preds


def get_field_preds_from_file(pred_file: str) -> Dict:
    """
    Get predictions from a JSON file.
    """
    field_preds = read_json(pred_file)
    return field_preds


def get_fields_preds(preds: List[Dict]) -> List[Dict]:
    # Keep the 'file_path'/'pred_data' keys produced by predict_one_file so
    # combine_to_single_file can consume this output directly.
    preds = [
        {
            "file_path": item["file_path"],
            "pred_data": format_output_api(item["pred_data"]),
        }
        for item in preds
    ]
    return preds


def combine_result(table_res: Dict, field_res: Dict) -> Dict:
    # TODO: merge the table and field evaluation results once table
    # evaluation is implemented.
    return {}


def _str2dict(text: str) -> Dict:
    try:
        data = json.loads(text)
    except Exception as err:
        logger.error(f"{err} - data: {text}")
        data = {}
    return data


def predict_one_file(url: str, file: Union[str, Path]) -> Dict:
    """
    Send a single file to the prediction endpoint.

    Output format:
        {
            "file_path": path of the file,
            "pred_data": parsed response body as a dict
        }
    """
    if isinstance(file, str):
        file = Path(file)
    payload = {}
    filename = file.name
    with open(str(file), 'rb') as f:
        files = [
            (
                'file',
                (
                    filename,
                    f,
                    'application/pdf'
                )
            )
        ]
        response = requests.request(
            "POST", url, headers=HEADERS, data=payload, files=files)
    response_dict = _str2dict(response.text)
    return {
        "file_path": str(file),
        "pred_data": response_dict
    }
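
# An item returned by predict_one_file looks like the following; the response
# body shape is inferred from format_output_api, values are illustrative:
#   {
#       "file_path": "/data/invoices/sample.pdf",
#       "pred_data": {"pages": [{"fields": [{"label": "seller", "value": "..."}]}]}
#   }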


def predict(url: str, files: List[Union[str, Path]]) -> List[Dict]:
    """
    Return a list of {'file_path', 'pred_data'} dicts, one per file.
    """
    preds = []
    for file in tqdm.tqdm(files):
        try:
            pred = predict_one_file(url, file)
            preds.append(pred)
        except Exception as err:
            logger.error(f"Error at file: {file} - {err}")
    return preds


def get_files(data_dir: str, recursive: bool = False, limit: Union[int, None] = None) -> List[Union[Path, str]]:
    if recursive:
        files = Path(data_dir).rglob("*")
    else:
        files = Path(data_dir).glob("*")
    # Materialize the generator so callers always receive a list.
    files = list(files)
    if limit:
        files = files[:limit]
    return files


def _stem_filename(filename: str) -> str:
    """
    Stem a file path: x/y.txt -> y
    """
    return Path(filename).stem


def format_output_api(output_api: Dict, skip_fields=('table',)) -> Dict:
    if "pages" not in output_api:
        return {}
    pages = output_api['pages']
    result = {}
    for page in pages:
        fields = page['fields']
        for field_item in fields:
            field_label, field_value = field_item['label'], field_item['value']
            # Keep only the first occurrence of each label; skip ignored fields.
            if field_label in result or field_label in skip_fields:
                continue
            result[field_label] = field_value
    return result
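
# Example (illustrative): format_output_api(
#     {"pages": [{"fields": [{"label": "seller", "value": "ACME Corp"},
#                            {"label": "table", "value": []}]}]}
# ) -> {"seller": "ACME Corp"}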


def combine_to_single_file(preds: List[Dict]) -> Dict:
    if len(preds) == 0:
        return {}
    combined_data = {
        _stem_filename(item["file_path"]): item["pred_data"]
        for item in preds
    }
    return combined_data


def evaluate(
    gt: Union[str, Dict],
    pred: Union[str, Dict],
    classes: List[str],
    classes_ignore: List[str] = []
) -> Dict:
    # Named `evaluate` rather than `eval` to avoid shadowing the builtin.
    eval_res = eval_kie(
        gt_e2e_path=gt,
        pred_e2e_path=pred,
        kie_labels=classes,
        skip_labels=classes_ignore
    )
    return eval_res


def get_classes(preds: Dict) -> List[str]:
    # Derive the label set from the first non-empty prediction.
    classes = []
    for v in preds.values():
        if v:
            classes = list(v.keys())
            break
    return classes
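
# Example (illustrative): given preds
#   {"a": {}, "b": {"seller": "...", "buyer": "..."}}
# get_classes returns ["seller", "buyer"].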


def test():
    url = "http://107.120.133.27:8082/predict/image"
    payload = {}
    with open(
        '/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in/(1 of 19)_HOADON_1C23TYY_50.pdf',
        'rb'
    ) as f:
        files = [
            ('file', ('(1 of 19)_HOADON_1C23TYY_50.pdf', f, 'application/pdf'))
        ]
        # Reuse the module-level HEADERS instead of duplicating the token here.
        response = requests.request(
            "POST", url, headers=HEADERS, data=payload, files=files)
    print(response.text)
    # print(json.loads(response.text))


if __name__ == "__main__":
    limit = 5
    run(
        data_dir="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in",
        url=URL,
        gt_path="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in.json",
        field_pred_file="/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/workdirs/invoice/06062023/invoice_all_in_final_e2e_21072023_5.json",
        samples=limit
    )
    # test()