432 lines
17 KiB
Python
432 lines
17 KiB
Python
|
import os
|
||
|
import re
|
||
|
import ast
|
||
|
import time
|
||
|
import json
|
||
|
import glob
|
||
|
import shutil
|
||
|
import pandas as pd
|
||
|
from tqdm import tqdm
|
||
|
from pathlib import Path
|
||
|
from datetime import datetime
|
||
|
from .ocr_metrics import eval_ocr_metric
|
||
|
|
||
|
import sys
|
||
|
# sys.path.append(os.path.dirname(__file__))
|
||
|
from sdsvkvu.utils.query.sbt_v2 import get_seller, post_process_seller
|
||
|
|
||
|
|
||
|
def read_json(file_path: str):
    """Load and return the JSON content of *file_path*.

    The encoding is pinned to utf8 for symmetry with write_to_json, so
    files written by this module always round-trip regardless of locale.
    """
    with open(file_path, 'r', encoding='utf8') as f:
        return json.load(f)
|
||
|
|
||
|
def write_to_json(file_path, content):
    """Serialize *content* to *file_path* as UTF-8 JSON, keeping non-ASCII
    characters verbatim (ensure_ascii=False)."""
    serialized = json.dumps(content, ensure_ascii=False)
    with open(file_path, mode='w', encoding='utf8') as fp:
        fp.write(serialized)
|
||
|
|
||
|
|
||
|
def convert_datetime_format(date_string: str, is_gt=False) -> str:
    """Convert a dd/mm/YYYY date string to ISO YYYY-MM-DD.

    Strings that are not exactly in dd/mm/YYYY form are returned unchanged
    (e.g. already-ISO dates like "2023-02-28").

    The *is_gt* flag is kept for interface compatibility; it is unused here.
    """
    output_format = "%Y-%m-%d"
    input_format = "%d/%m/%Y"
    # fullmatch (not match) so trailing garbage such as "28/02/2023 x" is
    # returned unchanged instead of passing the regex and crashing strptime.
    pattern = r"\d{2}\/\d{2}\/\d{4}"
    if re.fullmatch(pattern, date_string):
        # Convert the date string to a datetime object, then re-format.
        date_object = datetime.strptime(date_string, input_format)
        return date_object.strftime(output_format)
    return date_string
|
||
|
|
||
|
|
||
|
def normalise_retailer_name(retailer: str):
    """Normalise a raw retailer string via the sdsvkvu seller pipeline.

    Wraps the text in the box-like dict shape expected by get_seller
    (id/class/bbox are dummies) and returns the post-processed name.
    """
    seller_box = {
        "text": retailer,
        "id": 0,
        "class": "seller",
        "bbox": [0, 0, 0, 0],
    }
    raw_seller = get_seller({'seller': [seller_box]})
    return post_process_seller(raw_seller)
|
||
|
|
||
|
|
||
|
def post_processing_str(class_name: str, s: str, is_gt: bool) -> str:
    """Normalise a field value before edit-distance scoring.

    - replaces the '✪' placeholder with a space and strips whitespace
    - maps textual null markers ("null"/"nan"/"none") to ''
    - reformats ground-truth purchase dates to YYYY-MM-DD
    - normalises retailer names via the seller dictionary
    """
    s = str(s).replace('✪', ' ').strip()
    if s.lower() in ['null', 'nan', "none"]:
        return ''
    # Only ground-truth dates arrive as dd/mm/YYYY; predictions are
    # assumed to already be ISO formatted, so they are left untouched.
    if class_name == "purchase_date" and is_gt:
        s = convert_datetime_format(s)
    if class_name == "retailername":
        s = normalise_retailer_name(s)
    return s
|
||
|
|
||
|
|
||
|
def convert_groundtruth_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    """Convert the customer-submitted CSV into one ground-truth JSON per request.

    Each requestId gets save_dir/<requestId>/<requestId>.json containing the
    fields in *classes* (imei_number is accumulated as a de-duplicated list
    across rows of the same request).

    Fixes vs. original: missing IMEI cells (read as NaN by pandas) are
    dropped instead of being written out as invalid JSON NaN tokens, and
    request ids are stringified before path construction.
    """
    df = pd.read_csv(csv_path)

    total_output = {}
    for _, request in df.iterrows():
        req_id = request['requestId']

        if req_id not in total_output:
            total_output[req_id] = {k: None for k in classes}
            total_output[req_id]["imei_number"] = []

        # Collect both IMEI columns, skipping empty CSV cells (NaN).
        imeis = [v for v in (request["imeiNumber"], request["imeiNumber2"]) if pd.notna(v)]
        total_output[req_id]["imei_number"] = list(set(total_output[req_id]["imei_number"] + imeis))

        total_output[req_id]["purchase_date"] = request["Purchase Date"]
        total_output[req_id]["retailername"] = request["retailer"]

    for req_id, output in total_output.items():
        # str() guards against numeric request ids read from the CSV.
        save_path = os.path.join(save_dir, str(req_id))
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, f"{req_id}.json"), mode='w', encoding='utf8') as f:
            json.dump(output, f, ensure_ascii=False)
|
||
|
|
||
|
|
||
|
def convert_predict_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    # Split the OCR-side CSV into one temp_<doctype>_<request>_<idx>.json per
    # document page under save_dir/<request_id>/.
    #
    # NOTE(review): `classes` is a mutable default argument; it is only read
    # here, never mutated, so it is harmless — but a tuple default would be safer.
    # if isinstance(csv_path_list, str):
    #     csv_path_list = [csv_path_list]

    df = pd.read_csv(csv_path)

    for _, request in df.iterrows():
        n_pages = request['pages']
        req_id = request['request_id']
        # Rows with missing doc_type/predict_result (NaN floats) cannot be parsed.
        if not isinstance(request['doc_type'], str) or not isinstance(request['predict_result'], str):
            print(f"[WARNING] Skipped request id {req_id}")
            continue

        # One doc type per page, comma separated, e.g. "imei,imei,invoice".
        doc_type_list = request['doc_type'].split(',')
        assert n_pages == len(doc_type_list), \
            "No. pages is different no. documents"

        json_path = os.path.join(save_dir, req_id)
        os.makedirs(json_path, exist_ok=True)

        # For user_submitted_results
        # NOTE(review): this writes <req_id>.json — the same filename the
        # ground-truth converter uses; presumably the two inputs are not used
        # on the same save_dir at once. Verify before mixing.
        if "feedback_result" in request:
            feedback_data = ast.literal_eval(request['feedback_result'])
            fname = f"{req_id}.json"
            write_to_json(os.path.join(json_path, fname), feedback_data)

        # For predict_results
        # predict_result holds a repr'd dict; literal_eval is safe for literals
        # (unlike eval) but still assumes a trusted CSV.
        data = ast.literal_eval(request['predict_result'])['content']['document'][0]['content']
        # +0.1 appears to be a fixed overhead fudge term — TODO confirm.
        infer_time = float(request['ai_inference_time']) + float(request['preprocessing_time']) + 0.1

        n_imei, n_invoice = 0, 0
        for doc_type in doc_type_list:
            output = {k: None for k in classes}
            # Redundant with the makedirs above; kept as-is.
            if not os.path.exists(json_path):
                os.makedirs(json_path, exist_ok=True)

            if doc_type == "imei":
                # The n_imei-th imei page takes the n_imei-th predicted value.
                for info in data:
                    if info['label'] == "imei_number":
                        output['imei_number'] = info['value'][n_imei]
                        output['processing_time'] = infer_time
                        fname = f"temp_{doc_type}_{req_id}_{n_imei}.json"
                        write_to_json(os.path.join(json_path, fname), output)
                        n_imei += 1
                        break
            elif doc_type == "invoice":
                # Invoice pages take every non-imei field as-is.
                for info in data:
                    if info['label'] == "imei_number":
                        continue
                    output[info['label']] = info['value']
                output['processing_time'] = infer_time
                fname = f"temp_{doc_type}_{req_id}_{n_invoice}.json"
                write_to_json(os.path.join(json_path, fname), output)
                n_invoice += 1
|
||
|
|
||
|
|
||
|
def gen_req_to_red_dict(csv_path: str):
    """Build a requestId -> redemptionNumber lookup from the customer CSV."""
    frame = pd.read_csv(csv_path)
    return dict(zip(frame["requestId"], frame["redemptionNumber"]))
|
||
|
|
||
|
|
||
|
def gen_req_to_red_dict_2(csv_path: str):
    """Build a request_id -> redemption_id lookup from the internal CSV."""
    frame = pd.read_csv(csv_path)
    return frame.set_index("request_id")["redemption_id"].to_dict()
|
||
|
|
||
|
|
||
|
def init_csv(
    gt_dir: str,
    pred_dir: str,
    req_to_red: dict,
):
    # Build one evaluation row per field (imei_number, retailername,
    # purchase_date) by comparing each request's ground-truth JSON against
    # its temp_* prediction JSONs, scoring with 1 - normalized edit distance.
    # Returns a list of dicts, ready for pd.DataFrame(...).to_csv(...).
    list_request_id = os.listdir(gt_dir)
    total = []
    for request_id in list_request_id:
        gt_path = os.path.join(gt_dir, request_id, request_id+".json")
        if not os.path.exists(gt_path):
            print(f"[WARNING] Skipped request id {os.path.basename(os.path.dirname(gt_path))}")
            continue
        gt_data = read_json(gt_path)
        # temp_<doctype>_<request>_<idx>.json written by convert_predict_from_csv
        json_file_list = glob.glob(os.path.join(pred_dir, request_id, "temp_*.json"))
        # Sort by the trailing per-document index so imei files align with
        # the order of gt_data['imei_number'].
        json_file_list = sorted(json_file_list, key=lambda x: int(x.split(".json")[0].split('_')[-1]))
        n_imei, n_invoice = 0, 0
        # if len(json_file_list) > 3:
        #     continue

        for json_file in json_file_list:
            pred_data = read_json(json_file)
            if "imei" in json_file:
                pred_value = pred_data['imei_number']
                # NOTE(review): assumes the gt imei list is at least as long
                # as the number of imei pages — IndexError otherwise; confirm.
                gt_value = gt_data['imei_number'][n_imei]
                n_imei += 1
                score = eval_ocr_metric(
                    [post_processing_str("imei_number", pred_value, is_gt=False)],
                    [post_processing_str("imei_number", gt_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']

                total.append({
                    "requestId": request_id,
                    "redemptionNumber": req_to_red[request_id],
                    "userSubmitResults": gt_value,
                    "OCRResults": pred_value,
                    "revisedResults_by_SDSRV": "",
                    "accuracy": score,
                    "processingTime (by request)": pred_data['processing_time'],
                    "class_name": "imei_number",
                    "file_path": json_file
                })

            elif "invoice" in json_file:
                for class_name in ["retailername", "purchase_date"]:
                    pred_value = pred_data[class_name]
                    gt_value = gt_data[class_name]
                    if isinstance(gt_value, list):
                        gt_value = gt_value[0]
                    n_invoice += 1

                    if not isinstance(pred_value, list):
                        pred_value = [pred_value]

                    # Best score over all predicted candidates for the field.
                    score = 0
                    for _pred_value in pred_value:
                        score1 = eval_ocr_metric(
                            [post_processing_str(class_name, _pred_value, is_gt=False)],
                            [post_processing_str(class_name, gt_value, is_gt=True)],
                            metric=["one_minus_ned"]
                        )['1-N.E.D']
                        score = max(score, score1)

                    total.append({
                        "requestId": request_id,
                        "redemptionNumber": req_to_red[request_id],
                        "userSubmitResults": gt_value,
                        "OCRResults": pred_value[0] if class_name == "retailername" else pred_value,
                        "revisedResults_by_SDSRV": "",
                        "accuracy": score,
                        "processingTime (by request)": pred_data['processing_time'],
                        "class_name": class_name,
                        "file_path": json_file
                    })

    return total
|
||
|
|
||
|
|
||
|
def export_report(
    init_csv: str,
):
    # Re-score every row of the init CSV against the (possibly revised)
    # JSON files on disk and fill in the revisedResults_by_SDSRV and
    # accuracy columns. Returns the updated DataFrame.
    #
    # NOTE(review): the parameter name `init_csv` shadows the init_csv()
    # function at module level — harmless inside this scope, but confusing.
    df = pd.read_csv(init_csv)
    for index, request in df.iterrows():
        file_path = request['file_path']
        class_name = request['class_name']
        pred_value = request['OCRResults']
        revised_value = read_json(file_path)[class_name]
        if class_name == "purchase_date":
            # OCRResults for dates was stored as the repr of a list.
            pred_value = ast.literal_eval(pred_value)
            if isinstance(revised_value, list):
                if len(revised_value) > 0:
                    revised_value = revised_value[0]
                else:
                    revised_value = None

            if len(pred_value) == 0:
                pred_value = [None]

            # Best score over all predicted date candidates.
            score = 0
            for _pred_value in pred_value:
                score1 = eval_ocr_metric(
                    [post_processing_str(class_name, _pred_value, is_gt=False)],
                    [post_processing_str(class_name, revised_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']
                score = max(score, score1)
        else:
            # Scalar fields (imei_number, retailername) score directly.
            score = eval_ocr_metric(
                [post_processing_str(class_name, pred_value, is_gt=False)],
                [post_processing_str(class_name, revised_value, is_gt=True)],
                metric=["one_minus_ned"]
            )['1-N.E.D']

        df.at[index, "revisedResults_by_SDSRV"] = revised_value
        df.at[index, "accuracy"] = score

    return df
|
||
|
|
||
|
|
||
|
def pick_sample_to_revise(
    ocr_accuracy: list,
    gt_dir: str,
    save_dir: str
):
    """Copy image + prediction JSON of low-accuracy samples into review folders.

    Samples with accuracy == 0 go to save_dir/empty_results, those with
    0 < accuracy < 1 go to save_dir/diff_results; perfect samples are skipped.
    """
    empty_dir = os.path.join(save_dir, "empty_results")
    diff_dir = os.path.join(save_dir, "diff_results")
    os.makedirs(empty_dir, exist_ok=True)
    os.makedirs(diff_dir, exist_ok=True)

    for entry in ocr_accuracy:
        acc = entry['accuracy']
        json_path = entry['file_path']
        req_id = entry['requestId']

        # The source image lives next to the prediction JSON under gt_dir,
        # sharing the same stem but a different extension.
        parts = Path(json_path).parts
        pattern = os.path.join(gt_dir, parts[-2], parts[-1]).replace(".json", ".*")
        images = [candidate for candidate in glob.glob(pattern) if ".json" not in candidate]

        if not images:
            print(f"[WARNING] Skipped request id {req_id}")
            continue
        image_path = images[0]
        # img_path = [ff for ff in glob.glob(json_path.replace(".json", ".*"))][0]

        if acc == 0:
            dest = os.path.join(empty_dir, req_id)
        elif acc < 1:
            dest = os.path.join(diff_dir, req_id)
        else:
            continue
        os.makedirs(dest, exist_ok=True)
        shutil.copy(image_path, dest)
        shutil.copy(json_path, dest)
|
||
|
|
||
|
def merge_revised_sample(
    revised_path_list: list,
    save_dir: str
):
    """Copy revised JSON files back into the ground-truth folder.

    Keeps the original `sudo cp` semantics (the target dir may not be
    writable by the current user — see the chmod in __main__), but uses an
    argument list instead of a shell-interpolated f-string so paths with
    spaces or shell metacharacters are handled safely.
    """
    import subprocess  # local import: only this helper shells out

    if not isinstance(revised_path_list, list):
        revised_path_list = [revised_path_list]

    for revised_path in revised_path_list:
        list_request = [os.path.basename(ff) for ff in os.listdir(revised_path)]
        for request in list_request:
            file_list = glob.glob(os.path.join(revised_path, request, "*.json*"))
            for file_path in file_list:
                # shutil.copyfile(file_path, os.path.join(save_path, request))
                # check=False mirrors os.system's ignore-exit-status behavior
                subprocess.run(
                    ["sudo", "cp", file_path, os.path.join(save_dir, request)],
                    check=False,
                )
|
||
|
|
||
|
def calculate_average_by_column(df, column_name):
    """Average *column_name* over the first row of each requestId group.

    Request-level metrics (e.g. processing time) are duplicated on every
    image row of the same request, so only the first occurrence per request
    is counted. Returns 0 (not NaN) for an empty frame.

    Note: groupby groups are never empty, so the original per-group length
    guard was dead code; drop_duplicates keeps exactly the first row per
    requestId, matching sub_df.iloc[0] of the old loop.
    """
    firsts = df.drop_duplicates(subset=["requestId"])[column_name]
    if len(firsts) == 0:
        return 0
    return firsts.mean()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Evaluation driver: steps 1-3 are active, steps 4-6 (manual revision
    # round-trip and final report export) are kept commented for later runs.
    # NOTE(review): all paths below are machine-specific; parameterize before reuse.
    save_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan"
    save_csv = "logs/eval_20240115"
    csv_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan.csv"
    csv_path_end_user = "logs/eval_20240115/OCR_15Jan2024.csv"

    # Step 1: Convert a csv file to get user submitted results for each request
    print("[INFO] Starting convert csv from customer to json")
    # chmod so later steps (and the sudo cp in merge_revised_sample) can write here
    os.system(f"sudo chmod -R 777 {save_path}")
    convert_groundtruth_from_csv(csv_path=csv_path_end_user, save_dir=save_path)
    print("[INFO] Converted")

    # # Step 2: Convert a csv file to get predict OCR results for each image
    print("[INFO] Starting convert csv from SDSV to json")
    convert_predict_from_csv(csv_path=csv_path, save_dir=save_path)
    print("[INFO] Converted")

    # # Step 3: Gen initial csv file and calculate OCR result between submitted results and ocr results
    print("[INFO] Starting generate csv to get performance")
    gt_path = save_path
    pred_path = save_path
    req_to_red_dict = gen_req_to_red_dict(csv_path_end_user)
    init_data = init_csv(gt_dir=gt_path, pred_dir=pred_path, req_to_red=req_to_red_dict)
    pd.DataFrame(init_data).to_csv(os.path.join(save_csv, "init1.csv"), index=False)
    print("[INFO] Done")

    # # Step 4: Split requests whose accuracy is less than 1 to revise
    # print("[INFO] Starting split data to review")
    # revised_path = os.path.join(save_csv, "revised")
    # # shutil.rmtree(revised_path)
    # pick_sample_to_revise(ocr_accuracy=init_data, gt_dir=save_path, save_dir=revised_path)
    # print("[INFO] Done")

    # # Step 5: Merge revised results to gt folder
    # print("[INFO] Merging revised data to ground truth folder")
    # revised_path = os.path.join(save_csv, "revised")
    # revised_path = [f'{revised_path}/empty_results', f'{revised_path}/diff_results']
    # merge_revised_sample(revised_path_list=revised_path, save_dir=save_path)
    # print("Done")

    # # Step 6: Caculate OCR result between ocr results and revised results
    # print("[INFO] Exporting OCR report")
    # init_csv_path = os.path.join(save_csv, "init1.csv")
    # report = export_report(init_csv=init_csv_path)
    # error_path = os.path.join(save_csv, "errors")
    # pick_sample_to_revise(ocr_accuracy=report[report.accuracy < 0.75].to_dict('records'), gt_dir=save_path, save_dir=error_path)

    # n_total_images = len(report)
    # n_bad_images = len(report[report.accuracy < 0.75])
    # average_acc = report[report.accuracy >= 0.75]['accuracy'].mean()

    # print("Total requests:", len(report['requestId'].unique()))
    # print("Total images:", n_total_images)
    # print("No. imei images:", len(report[report.class_name == "imei_number"]))
    # print("No. invoice images:", len(report[report.class_name == "retailername"]))
    # print("No. bad quality images:", n_bad_images)
    # print("No. valid images:", n_total_images - n_bad_images)
    # print("No. per of bad quality images:", 100*n_bad_images/n_total_images)
    # print("Average accuracy:", 100*average_acc)

    # last_row = n_total_images
    # report.at[last_row, "requestId"] = "Total requests:"
    # report.at[last_row, "redemptionNumber"] = len(report['requestId'].unique())
    # report.at[last_row+1, "requestId"] = "Total images:"
    # report.at[last_row+1, "redemptionNumber"] = n_total_images
    # report.at[last_row+2, "requestId"] = "No. imei images:"
    # report.at[last_row+2, "redemptionNumber"] = len(report[report.class_name == "imei_number"])
    # report.at[last_row+3, "requestId"] = "No. invoice images:"
    # report.at[last_row+3, "redemptionNumber"] = len(report[report.class_name == "retailername"])
    # report.at[last_row+4, "requestId"] = "No. bad quality images:"
    # report.at[last_row+4, "redemptionNumber"] = n_bad_images
    # report.at[last_row+5, "requestId"] = "No. valid images:"
    # report.at[last_row+5, "redemptionNumber"] = n_total_images - n_bad_images
    # report.at[last_row+6, "requestId"] = "No. per of bad quality images:"
    # report.at[last_row+6, "redemptionNumber"] = 100*n_bad_images/n_total_images
    # report.at[last_row+7, "requestId"] = "Average accuracy:"
    # report.at[last_row+7, "redemptionNumber"] = 100*average_acc

    # report.drop(columns=["file_path", "class_name"]).to_csv(os.path.join(save_csv, f"SBT_report_{time.strftime('%Y%m%d')}.csv"), index=False)
    # print("[INFO] Done")
|
||
|
|
||
|
|