# sbt-idp/cope2n-api/fwd_api/utils/ocr_utils/sbt_report.py

import os
import re
import ast
import time
import json
import glob
import shutil
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from .ocr_metrics import eval_ocr_metric
import sys
# sys.path.append(os.path.dirname(__file__))
from sdsvkvu.utils.query.sbt_v2 import get_seller, post_process_seller


def read_json(file_path: str):
    with open(file_path, 'r') as f:
        return json.load(f)


def write_to_json(file_path, content):
    with open(file_path, mode='w', encoding='utf8') as f:
        json.dump(content, f, ensure_ascii=False)


def convert_datetime_format(date_string: str, is_gt=False) -> str:
    # pattern_date_string = "2023-02-28"
    output_format = "%Y-%m-%d"
    input_format = "%d/%m/%Y"

    # Validate the input date string format
    pattern = r"\d{2}\/\d{2}\/\d{4}"
    if re.match(pattern, date_string):
        # Convert the date string to a datetime object
        date_object = datetime.strptime(date_string, input_format)
        # Convert the datetime object to the desired output format
        formatted_date = date_object.strftime(output_format)
        return formatted_date
    return date_string
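# Illustrative behaviour of convert_datetime_format, inferred from the format
# strings above (not from recorded runs):
#   convert_datetime_format("28/02/2023") -> "2023-02-28"
#   convert_datetime_format("2023-02-28") -> "2023-02-28"  (no dd/mm/yyyy match, returned as-is)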


def normalise_retailer_name(retailer: str):
    input_value = {
        "text": retailer,
        "id": 0,
        "class": "seller",
        "bbox": [0, 0, 0, 0],
    }
    output = get_seller({'seller': [input_value]})
    norm_seller_name = post_process_seller(output)
    return norm_seller_name
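# normalise_retailer_name wraps the raw retailer string in a minimal pseudo
# "seller" entry (dummy id/bbox) so it can be pushed through sdsvkvu's
# get_seller / post_process_seller pipeline; the return value is whatever
# canonical seller name that library produces.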


def post_processing_str(class_name: str, s: str, is_gt: bool) -> str:
    s = str(s).replace('', ' ').strip()
    if s.lower() in ['null', 'nan', "none"]:
        return ''
    if class_name == "purchase_date" and is_gt:
        s = convert_datetime_format(s)
    if class_name == "retailername":
        s = normalise_retailer_name(s)
    return s
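# post_processing_str normalises a field value before scoring: placeholder
# values ("null"/"nan"/"none") are intended to collapse to an empty string,
# ground-truth purchase dates are re-formatted to yyyy-mm-dd, and retailer
# names are canonicalised via normalise_retailer_name.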


def convert_groundtruth_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    # if isinstance(csv_path_list, str):
    #     csv_path_list = [csv_path_list]
    df = pd.read_csv(csv_path)

    total_output = {}
    for _, request in df.iterrows():
        req_id = request['requestId']
        if req_id not in total_output:
            total_output[req_id] = {k: None for k in classes}
            total_output[req_id]["imei_number"] = []

        total_output[req_id]["imei_number"].extend([request["imeiNumber"], request["imeiNumber2"]])
        total_output[req_id]["imei_number"] = list(set(total_output[req_id]["imei_number"]))
        total_output[req_id]["purchase_date"] = request["Purchase Date"]
        total_output[req_id]["retailername"] = request["retailer"]

    for req_id, output in total_output.items():
        save_path = os.path.join(save_dir, req_id)
        os.makedirs(save_path, exist_ok=True)
        write_to_json(os.path.join(save_path, f"{req_id}.json"), output)
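# Layout produced by convert_groundtruth_from_csv, based on the columns read
# above (requestId, imeiNumber, imeiNumber2, "Purchase Date", retailer):
#   <save_dir>/<requestId>/<requestId>.json
#   e.g. {"retailername": <retailer>, "sold_to_party": None,
#         "purchase_date": <Purchase Date>, "imei_number": [<imei 1>, <imei 2>]}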


def convert_predict_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    # if isinstance(csv_path_list, str):
    #     csv_path_list = [csv_path_list]
    df = pd.read_csv(csv_path)

    for _, request in df.iterrows():
        n_pages = request['pages']
        req_id = request['request_id']
        if not isinstance(request['doc_type'], str) or not isinstance(request['predict_result'], str):
            print(f"[WARNING] Skipped request id {req_id}")
            continue

        doc_type_list = request['doc_type'].split(',')
        assert n_pages == len(doc_type_list), \
            "Number of pages differs from the number of documents"

        json_path = os.path.join(save_dir, req_id)
        os.makedirs(json_path, exist_ok=True)

        # For user_submitted_results
        if "feedback_result" in request:
            feedback_data = ast.literal_eval(request['feedback_result'])
            fname = f"{req_id}.json"
            write_to_json(os.path.join(json_path, fname), feedback_data)

        # For predict_results
        data = ast.literal_eval(request['predict_result'])['content']['document'][0]['content']
        infer_time = float(request['ai_inference_time']) + float(request['preprocessing_time']) + 0.1

        n_imei, n_invoice = 0, 0
        for doc_type in doc_type_list:
            output = {k: None for k in classes}
            if not os.path.exists(json_path):
                os.makedirs(json_path, exist_ok=True)

            if doc_type == "imei":
                for info in data:
                    if info['label'] == "imei_number":
                        output['imei_number'] = info['value'][n_imei]
                        output['processing_time'] = infer_time
                        fname = f"temp_{doc_type}_{req_id}_{n_imei}.json"
                        write_to_json(os.path.join(json_path, fname), output)
                        n_imei += 1
                        break
            elif doc_type == "invoice":
                for info in data:
                    if info['label'] == "imei_number":
                        continue
                    output[info['label']] = info['value']
                output['processing_time'] = infer_time
                fname = f"temp_{doc_type}_{req_id}_{n_invoice}.json"
                write_to_json(os.path.join(json_path, fname), output)
                n_invoice += 1
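# convert_predict_from_csv mirrors the ground-truth layout for predictions: per
# request it writes the user feedback (if present) to
# <save_dir>/<request_id>/<request_id>.json, plus one
# temp_<doc_type>_<request_id>_<index>.json per page, where imei pages carry the
# n-th predicted imei_number and invoice pages carry the remaining fields,
# together with a shared processing_time (ai_inference_time + preprocessing_time + 0.1).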


def gen_req_to_red_dict(csv_path: str):
    df = pd.read_csv(csv_path)
    df = df.loc[:, ["requestId", "redemptionNumber"]]
    req_to_red = {row["requestId"]: row["redemptionNumber"] for _, row in df.iterrows()}
    return req_to_red


def gen_req_to_red_dict_2(csv_path: str):
    df = pd.read_csv(csv_path)
    df = df.loc[:, ["request_id", "redemption_id"]]
    req_to_red = {row["request_id"]: row["redemption_id"] for _, row in df.iterrows()}
    return req_to_red
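# Both helpers build a {request id -> redemption id} lookup; gen_req_to_red_dict
# expects the customer-style columns (requestId / redemptionNumber) while
# gen_req_to_red_dict_2 expects the snake_case columns (request_id / redemption_id).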


def init_csv(
    gt_dir: str,
    pred_dir: str,
    req_to_red: dict,
):
    list_request_id = os.listdir(gt_dir)
    total = []
    for request_id in list_request_id:
        gt_path = os.path.join(gt_dir, request_id, request_id + ".json")
        if not os.path.exists(gt_path):
            print(f"[WARNING] Skipped request id {os.path.basename(os.path.dirname(gt_path))}")
            continue
        gt_data = read_json(gt_path)
        json_file_list = glob.glob(os.path.join(pred_dir, request_id, "temp_*.json"))
        json_file_list = sorted(json_file_list, key=lambda x: int(x.split(".json")[0].split('_')[-1]))
        n_imei, n_invoice = 0, 0

        # if len(json_file_list) > 3:
        #     continue

        for json_file in json_file_list:
            pred_data = read_json(json_file)
            if "imei" in json_file:
                pred_value = pred_data['imei_number']
                gt_value = gt_data['imei_number'][n_imei]
                n_imei += 1

                score = eval_ocr_metric(
                    [post_processing_str("imei_number", pred_value, is_gt=False)],
                    [post_processing_str("imei_number", gt_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']

                total.append({
                    "requestId": request_id,
                    "redemptionNumber": req_to_red[request_id],
                    "userSubmitResults": gt_value,
                    "OCRResults": pred_value,
                    "revisedResults_by_SDSRV": "",
                    "accuracy": score,
                    "processingTime (by request)": pred_data['processing_time'],
                    "class_name": "imei_number",
                    "file_path": json_file
                })
            elif "invoice" in json_file:
                for class_name in ["retailername", "purchase_date"]:
                    pred_value = pred_data[class_name]
                    gt_value = gt_data[class_name]
                    if isinstance(gt_value, list):
                        gt_value = gt_value[0]
                    n_invoice += 1

                    if not isinstance(pred_value, list):
                        pred_value = [pred_value]

                    score = 0
                    for _pred_value in pred_value:
                        score1 = eval_ocr_metric(
                            [post_processing_str(class_name, _pred_value, is_gt=False)],
                            [post_processing_str(class_name, gt_value, is_gt=True)],
                            metric=["one_minus_ned"]
                        )['1-N.E.D']
                        score = max(score, score1)

                    total.append({
                        "requestId": request_id,
                        "redemptionNumber": req_to_red[request_id],
                        "userSubmitResults": gt_value,
                        "OCRResults": pred_value[0] if class_name == "retailername" else pred_value,
                        "revisedResults_by_SDSRV": "",
                        "accuracy": score,
                        "processingTime (by request)": pred_data['processing_time'],
                        "class_name": class_name,
                        "file_path": json_file
                    })

    return total
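# Each entry appended by init_csv becomes one row of the initial report CSV
# (values illustrative):
#   {"requestId": <request id>, "redemptionNumber": <redemption id>,
#    "userSubmitResults": <ground truth>, "OCRResults": <prediction>,
#    "revisedResults_by_SDSRV": "", "accuracy": <1-N.E.D score>,
#    "processingTime (by request)": <seconds>,
#    "class_name": "imei_number" | "retailername" | "purchase_date",
#    "file_path": <temp_*.json path>}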


def export_report(
    init_csv: str,
):
    df = pd.read_csv(init_csv)
    for index, request in df.iterrows():
        file_path = request['file_path']
        class_name = request['class_name']
        pred_value = request['OCRResults']
        revised_value = read_json(file_path)[class_name]

        if class_name == "purchase_date":
            pred_value = ast.literal_eval(pred_value)
            if isinstance(revised_value, list):
                if len(revised_value) > 0:
                    revised_value = revised_value[0]
                else:
                    revised_value = None
            if len(pred_value) == 0:
                pred_value = [None]

            score = 0
            for _pred_value in pred_value:
                score1 = eval_ocr_metric(
                    [post_processing_str(class_name, _pred_value, is_gt=False)],
                    [post_processing_str(class_name, revised_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']
                score = max(score, score1)
        else:
            score = eval_ocr_metric(
                [post_processing_str(class_name, pred_value, is_gt=False)],
                [post_processing_str(class_name, revised_value, is_gt=True)],
                metric=["one_minus_ned"]
            )['1-N.E.D']

        df.at[index, "revisedResults_by_SDSRV"] = revised_value
        df.at[index, "accuracy"] = score

    return df
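# export_report re-reads the initial CSV, takes the revised value from each
# row's file_path JSON (assumed to have been corrected in place during review),
# and recomputes the 1-N.E.D accuracy against it; purchase_date predictions are
# stored as a stringified list, hence the ast.literal_eval before scoring.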


def pick_sample_to_revise(
    ocr_accuracy: list,
    gt_dir: str,
    save_dir: str
):
    empty_err_path = os.path.join(save_dir, "empty_results")
    other_err_path = os.path.join(save_dir, "diff_results")
    os.makedirs(empty_err_path, exist_ok=True)
    os.makedirs(other_err_path, exist_ok=True)

    for request in ocr_accuracy:
        score = request['accuracy']
        json_path = request['file_path']
        request_id = request['requestId']

        img_path_folder = os.path.join(gt_dir, Path(json_path).parts[-2], Path(json_path).parts[-1])
        img_path = [ff for ff in glob.glob(img_path_folder.replace(".json", ".*")) if ".json" not in ff]
        if len(img_path) == 0:
            print(f"[WARNING] Skipped request id {request_id}")
            continue
        img_path = img_path[0]
        # img_path = [ff for ff in glob.glob(json_path.replace(".json", ".*"))][0]

        if score == 0:
            save_path = os.path.join(empty_err_path, request_id)
        elif score < 1:
            save_path = os.path.join(other_err_path, request_id)
        else:
            continue

        os.makedirs(save_path, exist_ok=True)
        shutil.copy(img_path, save_path)
        shutil.copy(json_path, save_path)
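# pick_sample_to_revise buckets low-scoring samples for manual review:
# accuracy == 0 goes to <save_dir>/empty_results/<requestId>/ and
# 0 < accuracy < 1 to <save_dir>/diff_results/<requestId>/, copying both the
# prediction JSON and the matching image found next to it under gt_dir.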


def merge_revised_sample(
    revised_path_list: list,
    save_dir: str
):
    if not isinstance(revised_path_list, list):
        revised_path_list = [revised_path_list]

    for revised_path in revised_path_list:
        list_request = [os.path.basename(ff) for ff in os.listdir(revised_path)]
        for request in list_request:
            file_list = glob.glob(os.path.join(revised_path, request, "*.json*"))
            for file_path in file_list:
                # shutil.copyfile(file_path, os.path.join(save_path, request))
                os.system(f"sudo cp {file_path} {os.path.join(save_dir, request)}")


def calculate_average_by_column(df, column_name):
    df = df.groupby(by=["requestId"])
    time_list = []
    for req, sub_df in df:
        if len(sub_df) > 0:
            time_list.append(sub_df.iloc[0][column_name])
    if len(time_list) > 0:
        return sum(time_list) / len(time_list)
    return 0
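# Illustrative use of calculate_average_by_column (hypothetical call, not part
# of the pipeline below): average the per-request processing time over the
# report dataframe, counting each requestId once via the first row of its group:
#   calculate_average_by_column(report, "processingTime (by request)")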


if __name__ == "__main__":
    save_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan"
    save_csv = "logs/eval_20240115"
    csv_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan.csv"
    csv_path_end_user = "logs/eval_20240115/OCR_15Jan2024.csv"
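    # Note: save_csv ("logs/eval_20240115") is assumed to already exist;
    # the DataFrame.to_csv call in Step 3 does not create missing directories.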

    # Step 1: Convert a csv file to get user submitted results for each request
    print("[INFO] Starting convert csv from customer to json")
    os.system(f"sudo chmod -R 777 {save_path}")
    convert_groundtruth_from_csv(csv_path=csv_path_end_user, save_dir=save_path)
    print("[INFO] Converted")

    # # Step 2: Convert a csv file to get predicted OCR results for each image
    print("[INFO] Starting convert csv from SDSV to json")
    convert_predict_from_csv(csv_path=csv_path, save_dir=save_path)
    print("[INFO] Converted")

    # # Step 3: Generate the initial csv file and calculate the OCR accuracy between submitted results and OCR results
    print("[INFO] Starting generate csv to get performance")
    gt_path = save_path
    pred_path = save_path
    req_to_red_dict = gen_req_to_red_dict(csv_path_end_user)
    init_data = init_csv(gt_dir=gt_path, pred_dir=pred_path, req_to_red=req_to_red_dict)
    pd.DataFrame(init_data).to_csv(os.path.join(save_csv, "init1.csv"), index=False)
    print("[INFO] Done")

    # # Step 4: Split requests whose accuracy is less than 1 to revise
    # print("[INFO] Starting split data to review")
    # revised_path = os.path.join(save_csv, "revised")
    # # shutil.rmtree(revised_path)
    # pick_sample_to_revise(ocr_accuracy=init_data, gt_dir=save_path, save_dir=revised_path)
    # print("[INFO] Done")

    # # Step 5: Merge revised results into the gt folder
    # print("[INFO] Merging revised data to ground truth folder")
    # revised_path = os.path.join(save_csv, "revised")
    # revised_path = [f'{revised_path}/empty_results', f'{revised_path}/diff_results']
    # merge_revised_sample(revised_path_list=revised_path, save_dir=save_path)
    # print("Done")

    # # Step 6: Calculate the OCR accuracy between OCR results and revised results
    # print("[INFO] Exporting OCR report")
    # init_csv_path = os.path.join(save_csv, "init1.csv")
    # report = export_report(init_csv=init_csv_path)
    # error_path = os.path.join(save_csv, "errors")
    # pick_sample_to_revise(ocr_accuracy=report[report.accuracy < 0.75].to_dict('records'), gt_dir=save_path, save_dir=error_path)

    # n_total_images = len(report)
    # n_bad_images = len(report[report.accuracy < 0.75])
    # average_acc = report[report.accuracy >= 0.75]['accuracy'].mean()

    # print("Total requests:", len(report['requestId'].unique()))
    # print("Total images:", n_total_images)
    # print("No. imei images:", len(report[report.class_name == "imei_number"]))
    # print("No. invoice images:", len(report[report.class_name == "retailername"]))
    # print("No. bad quality images:", n_bad_images)
    # print("No. valid images:", n_total_images - n_bad_images)
    # print("No. per of bad quality images:", 100*n_bad_images/n_total_images)
    # print("Average accuracy:", 100*average_acc)

    # last_row = n_total_images
    # report.at[last_row, "requestId"] = "Total requests:"
    # report.at[last_row, "redemptionNumber"] = len(report['requestId'].unique())
    # report.at[last_row+1, "requestId"] = "Total images:"
    # report.at[last_row+1, "redemptionNumber"] = n_total_images
    # report.at[last_row+2, "requestId"] = "No. imei images:"
    # report.at[last_row+2, "redemptionNumber"] = len(report[report.class_name == "imei_number"])
    # report.at[last_row+3, "requestId"] = "No. invoice images:"
    # report.at[last_row+3, "redemptionNumber"] = len(report[report.class_name == "retailername"])
    # report.at[last_row+4, "requestId"] = "No. bad quality images:"
    # report.at[last_row+4, "redemptionNumber"] = n_bad_images
    # report.at[last_row+5, "requestId"] = "No. valid images:"
    # report.at[last_row+5, "redemptionNumber"] = n_total_images - n_bad_images
    # report.at[last_row+6, "requestId"] = "No. per of bad quality images:"
    # report.at[last_row+6, "redemptionNumber"] = 100*n_bad_images/n_total_images
    # report.at[last_row+7, "requestId"] = "Average accuracy:"
    # report.at[last_row+7, "redemptionNumber"] = 100*average_acc

    # report.drop(columns=["file_path", "class_name"]).to_csv(os.path.join(save_csv, f"SBT_report_{time.strftime('%Y%m%d')}.csv"), index=False)
    # print("[INFO] Done")