"""Evaluate KIE predictions against ground truth and print a verification summary.

Documents can be sent to a prediction API (``predict``), or previously saved
field predictions can be loaded from JSON (``get_field_preds_from_file``).
The predictions are scored with ``sdsvkie.utils.eval_kie.eval_kie``.
"""
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

import requests
import tqdm

from sdsvkie.utils.eval_kie import eval_kie
from sdsvkie.utils.io_file import read_json, write_json

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): a bearer token is committed to source control here. It can be
# overridden with the SDSVKIE_API_TOKEN environment variable; the committed
# fallback should be rotated and removed.
_API_TOKEN = os.environ.get(
    "SDSVKIE_API_TOKEN",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2OTA2ODk4ODcsInVzZXJuYW1lIjoiYWRtaW4ifQ.Oybpc9tBsN35vCn3jzekkABDQKJT6yO1aBBJ4rMNln0",
)
HEADERS = {
    'accept': 'application/json',
    'Authorization': f'Bearer {_API_TOKEN}',
}
URL = "http://107.120.133.27:8082/predict/image"


def run(
    data_dir: str,
    url: str,
    gt_path: str,
    field_pred_file: str,
    samples: Union[int, None] = None,
):
    """Run the evaluation pipeline and print a verification summary.

    Args:
        data_dir: Directory holding the documents sent to the API.
        url: Prediction endpoint.
        gt_path: Ground-truth file passed to ``eval_kie``.
        field_pred_file: JSON file with saved field predictions, mapping each
            document to its ``{label: value}`` dict (the shape
            ``get_classes`` expects).
        samples: Optional cap on the number of documents sent to the API.

    Returns:
        The combined result dict (currently always empty; see TODO below).

    Raises:
        ValueError: If no field classes can be derived from the predictions.
    """
    files = get_files(data_dir, recursive=False, limit=samples)
    # NOTE(review): the API predictions are only consumed by the commented-out
    # paths below; the field evaluation reads ``field_pred_file`` instead.
    preds = predict(url, files)

    # TODO: table evaluation
    # table_preds = get_table_preds(preds)
    # table_eval_result = ...

    # Field evaluation (seller, buyer, ...).
    # field_preds = get_field_preds_from_api(api_preds=preds)
    field_preds = get_field_preds_from_file(pred_file=field_pred_file)
    classes = get_classes(preds=field_preds)
    if not classes:
        raise ValueError("Can not get the classes list")

    field_eval_result = evaluate(
        gt=gt_path,
        pred=field_preds,
        classes=classes,
        classes_ignore=['other', 'table'],
    )
    print(field_eval_result)

    # TODO: combine table and field results, e.g.
    # combined = combine_result(table_res=table_eval_result, field_res=field_eval_result)
    combined = {}

    print_result(
        data_path=data_dir,
        num_samples=len(field_preds),
        target_level=0.05,
        result=1.0,  # TODO: placeholder -- plug in the real achieved level
    )
    return combined


def print_result(
    data_path: str,
    num_samples: int,
    target_level: float,
    result: float,
    metric: str = "NLD",
    avg_time: float = 1.6363,
):
    """Print a human-readable verification summary.

    The verdict is ``PASS`` when ``result`` is strictly greater than
    ``target_level`` and ``FAILED`` otherwise. The ``avg_time`` default looks
    like a hard-coded placeholder -- TODO confirm.
    """
    # NOTE(review): if ``metric`` is a distance (lower is better), this
    # comparison may be inverted -- confirm the intended direction.
    verdict = 'PASS' if result > target_level else 'FAILED'
    print(
        f"Path of validation dataset: {data_path}\n"
        f"Number of validation dataset: {num_samples}\n"
        f"Evaluation metric: {metric}\n"
        f"Target level: {target_level}\n"
        f"Achieved level: {result}\n"
        f"Average time: {avg_time}\n"
        f"Verification result: {verdict}"
    )


def get_field_preds_from_api(api_preds: List[Dict]) -> dict:
    """Convert records returned by ``predict`` into ``{file_stem: fields}``."""
    return combine_to_single_file(get_fields_preds(api_preds))


def get_field_preds_from_file(pred_file: str) -> dict:
    """Load saved field predictions from a JSON file."""
    return read_json(pred_file)


def get_fields_preds(preds: List[Dict]) -> List[Dict]:
    """Flatten each API record's response with ``format_output_api``.

    Args:
        preds: Records as returned by ``predict``, i.e. dicts with
            ``file_path`` and ``pred_data`` (the parsed API response).

    Returns:
        Records of the same shape, with ``pred_data`` replaced by the
        flattened ``{label: value}`` dict.
    """
    records = []
    for item in preds:
        # ``predict_one_file`` stores the response under "pred_data"; the
        # original reader used "response_dict", so accept either key.
        raw = item["pred_data"] if "pred_data" in item else item["response_dict"]
        records.append(
            {"file_path": item["file_path"], "pred_data": format_output_api(raw)}
        )
    return records


def combine_result(table_res: Dict, field_res: Dict):
    """Placeholder for merging table and field results; returns ``{}``."""
    return {}


def _str2dict(text: str) -> Dict:
    """Parse ``text`` as a JSON object, logging and returning ``{}`` on failure."""
    try:
        data = json.loads(text)
    except (ValueError, TypeError) as err:  # JSONDecodeError is a ValueError
        logger.error("%s - data: %s", err, text)
        return {}
    if not isinstance(data, dict):
        # Downstream code treats the result as a mapping ("pages" in data).
        logger.error("Expected a JSON object - data: %s", text)
        return {}
    return data


def predict_one_file(url: str, file: Union[str, Path]) -> Dict:
    """Upload one document to the prediction API.

    Returns:
        ``{"file_path": str(file), "pred_data": <parsed JSON response or {}>}``
    """
    file = Path(file)
    # NOTE(review): every file is sent as application/pdf although the endpoint
    # is /predict/image -- confirm the server does not depend on the MIME type.
    with open(file, 'rb') as fh:  # closed once the request has been sent
        files = [('file', (file.name, fh, 'application/pdf'))]
        response = requests.request(
            "POST", url, headers=HEADERS, data={}, files=files
        )
    return {"file_path": str(file), "pred_data": _str2dict(response.text)}


def predict(url: str, files: List[Union[str, Path]]) -> List[Dict]:
    """Call ``predict_one_file`` on each file, skipping the ones that fail.

    Returns:
        ``{"file_path", "pred_data"}`` records for the files that succeeded.
    """
    preds = []
    for file in tqdm.tqdm(files):
        try:
            preds.append(predict_one_file(url, file))
        except Exception:
            # Best effort: log with traceback and move on. (The former bare
            # ``except:`` also swallowed KeyboardInterrupt.)
            logger.exception("Error at file: %s", file)
    return preds


def get_files(
    data_dir: str, recursive: bool = False, limit: Union[int, None] = None
) -> List[Path]:
    """List the regular files in ``data_dir`` in sorted order.

    Args:
        data_dir: Directory to scan.
        recursive: Also descend into sub-directories.
        limit: Keep at most this many files; ``None`` keeps all.
    """
    root = Path(data_dir)
    entries = root.rglob("*") if recursive else root.glob("*")
    # Sorting makes ``limit`` pick a deterministic subset; directories are
    # dropped because they cannot be uploaded.
    files = sorted(p for p in entries if p.is_file())
    if limit is not None:
        files = files[:limit]
    return files


def _stem_filename(filename: str) -> str:
    """Stem a file path: x/y.txt -> y"""
    return Path(filename).stem


def format_output_api(
    output_api: Dict, skip_fields: Sequence[str] = ('table',)
) -> Dict:
    """Flatten an API response into ``{label: value}``.

    Walks ``output_api["pages"][*]["fields"]``. The first occurrence of a label
    wins, and labels in ``skip_fields`` are dropped. Returns ``{}`` when the
    response has no ``"pages"`` key.
    """
    if "pages" not in output_api:
        return {}
    result = {}
    for page in output_api['pages']:
        for field_item in page['fields']:
            label = field_item['label']
            if label in result or label in skip_fields:
                continue
            result[label] = field_item['value']
    return result


def combine_to_single_file(preds: List[Dict]) -> Dict:
    """Merge per-file records into a single ``{file_stem: pred_data}`` dict.

    Records are keyed by ``file_path`` (as produced by ``predict``); a legacy
    ``filename`` key is also accepted.
    """
    combined_data = {}
    for item in preds:
        path = item["file_path"] if "file_path" in item else item["filename"]
        combined_data[_stem_filename(path)] = item["pred_data"]
    return combined_data


def evaluate(
    gt: Union[str, Dict],
    pred: Union[str, Dict],
    classes: List[str],
    classes_ignore: Optional[List[str]] = None,
) -> Dict:
    """Score ``pred`` against ``gt`` with ``eval_kie``.

    Args:
        gt: Ground truth (path or loaded data), forwarded as ``gt_e2e_path``.
        pred: Predictions (path or loaded data), forwarded as ``pred_e2e_path``.
        classes: Labels to evaluate.
        classes_ignore: Labels to skip; ``None`` means none.
    """
    return eval_kie(
        gt_e2e_path=gt,
        pred_e2e_path=pred,
        kie_labels=classes,
        skip_labels=[] if classes_ignore is None else classes_ignore,
    )


# Backward-compatible alias for existing callers; it shadows the builtin ``eval``.
eval = evaluate


def get_classes(preds: Dict) -> List[str]:
    """Return the labels of the first non-empty prediction, or ``[]``."""
    return next((list(fields.keys()) for fields in preds.values() if fields), [])


def test():
    """Manual smoke test: upload one sample PDF and print the raw response."""
    sample = Path(
        '/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in/(1 of 19)_HOADON_1C23TYY_50.pdf'
    )
    with open(sample, 'rb') as fh:
        files = [('file', (sample.name, fh, 'application/pdf'))]
        response = requests.request(
            "POST", URL, headers=HEADERS, data={}, files=files
        )
    print(response.text)


if __name__ == "__main__":
    limit = 5
    run(
        data_dir="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in",
        url=URL,
        gt_path="/mnt/ssd1T/hoanglv/Projects/KIE/DATA/dev_model/Invoice/processed/test/PV2/final/all_in.json",
        field_pred_file="/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/workdirs/invoice/06062023/invoice_all_in_final_e2e_21072023_5.json",
        samples=limit,
    )
    # test()  # uncomment for a single-file API smoke test