# sbt-idp/cope2n-api/fwd_api/utils/ocr_utils/sbt_report.py

import os
import re
import ast
import time
import json
import glob
import shutil
from django.conf import settings
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from .ocr_metrics import eval_ocr_metric
import sys
# sys.path.append(os.path.dirname(__file__))
from sdsvkvu.utils.query.sbt_v2 import get_seller, post_process_seller
import logging
logger = logging.getLogger(__name__)


def read_json(file_path: str):
    with open(file_path, 'r') as f:
        return json.load(f)


def write_to_json(file_path, content):
    with open(file_path, mode='w', encoding='utf8') as f:
        json.dump(content, f, ensure_ascii=False)


def convert_datetime_format(date_string: str, is_gt=False) -> str:
    # pattern_date_string = "2023-02-28"
    output_format = "%Y-%m-%d"
    input_format = "%d/%m/%Y"

    # Validate the input date string format
    pattern = r"\d{2}\/\d{2}\/\d{4}"
    if re.match(pattern, date_string):
        # Convert the date string to a datetime object
        date_object = datetime.strptime(date_string, input_format)
        # Convert the datetime object to the desired output format
        formatted_date = date_object.strftime(output_format)
        return formatted_date
    return date_string
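# Illustrative usage (hypothetical values): convert_datetime_format("28/02/2023") -> "2023-02-28";
# strings that do not match the dd/mm/yyyy pattern are returned unchanged.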


def normalise_retailer_name(retailer: str, sub: str):
    # Wrap the raw retailer string in the single-box format expected by sdsvkvu
    input_value = {
        "text": retailer,
        "id": 0,
        "class": "seller",
        "bbox": [0, 0, 0, 0],
    }
    output = get_seller({'seller': [input_value]}, sub)
    norm_seller_name = post_process_seller(output, sub)
    return norm_seller_name


def post_processing_str(class_name: str, s: str, is_gt: bool, sub: str = None) -> str:
    # `sub` is only needed for retailer-name normalisation; some call sites in this module omit it
    s = str(s).replace('✪', ' ').strip()  # '✪' assumed to be the multi-value separator in KVU outputs
    if s.lower() in ['null', 'nan', "none"]:
        return ''
    if class_name == "purchase_date" and is_gt == True:
        s = convert_datetime_format(s)
    if class_name == "retailername":
        s = normalise_retailer_name(s, sub)
    if class_name == "imei_number" and isinstance(s, str):
        if len(s) > settings.IMEI_MAX_LENGHT:
            s = s[:settings.IMEI_MAX_LENGHT]
    return s
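# Illustrative usage (hypothetical values): post_processing_str("purchase_date", "28/02/2023", is_gt=True)
# returns "2023-02-28"; "null"/"nan"/"none" inputs collapse to the empty string.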


def convert_groundtruth_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    # if isinstance(csv_path_list, str):
    #     csv_path_list = [csv_path_list]
    df = pd.read_csv(csv_path)

    total_output = {}
    for _, request in df.iterrows():
        req_id = request['requestId']
        if req_id not in total_output:
            total_output[req_id] = {k: None for k in classes}
            total_output[req_id]["imei_number"] = []

        total_output[req_id]["imei_number"].extend([request["imeiNumber"], request["imeiNumber2"]])
        total_output[req_id]["imei_number"] = list(set(total_output[req_id]["imei_number"]))
        total_output[req_id]["purchase_date"] = request["Purchase Date"]
        total_output[req_id]["retailername"] = request["retailer"]

    for req_id, output in total_output.items():
        save_path = os.path.join(save_dir, req_id)
        os.makedirs(save_path, exist_ok=True)
        write_to_json(os.path.join(save_path, f"{req_id}.json"), output)
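# Resulting layout (hypothetical request id): <save_dir>/REQ123/REQ123.json containing, e.g.,
# {"retailername": "...", "sold_to_party": None, "purchase_date": "28/02/2023", "imei_number": [...]}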


def convert_predict_from_csv(
    csv_path: str,
    save_dir: str,
    classes: list = ["retailername", "sold_to_party", "purchase_date", "imei_number"]
):
    # if isinstance(csv_path_list, str):
    #     csv_path_list = [csv_path_list]
    df = pd.read_csv(csv_path)
    for _, request in df.iterrows():
        n_pages = request['pages']
        req_id = request['request_id']
        if not isinstance(request['doc_type'], str) or not isinstance(request['predict_result'], str):
            logger.warning(f"Skipped request id {req_id}")
            continue

        doc_type_list = request['doc_type'].split(',')
        assert n_pages == len(doc_type_list), \
            "Number of pages does not match number of documents"

        json_path = os.path.join(save_dir, req_id)
        os.makedirs(json_path, exist_ok=True)

        # For user_submitted_results
        if "feedback_result" in request:
            feedback_data = ast.literal_eval(request['feedback_result'])
            fname = f"{req_id}.json"
            write_to_json(os.path.join(json_path, fname), feedback_data)

        # For predict_results
        data = ast.literal_eval(request['predict_result'])['content']['document'][0]['content']
        infer_time = float(request['ai_inference_time']) + float(request['preprocessing_time']) + 0.1
        n_imei, n_invoice = 0, 0
        for doc_type in doc_type_list:
            output = {k: None for k in classes}
            if not os.path.exists(json_path):
                os.makedirs(json_path, exist_ok=True)

            if doc_type == "imei":
                for info in data:
                    if info['label'] == "imei_number":
                        output['imei_number'] = info['value'][n_imei]
                        output['processing_time'] = infer_time
                        fname = f"temp_{doc_type}_{req_id}_{n_imei}.json"
                        write_to_json(os.path.join(json_path, fname), output)
                        n_imei += 1
                        break
            elif doc_type == "invoice":
                for info in data:
                    if info['label'] == "imei_number":
                        continue
                    output[info['label']] = info['value']
                output['processing_time'] = infer_time
                fname = f"temp_{doc_type}_{req_id}_{n_invoice}.json"
                write_to_json(os.path.join(json_path, fname), output)
                n_invoice += 1
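# Resulting layout per request (hypothetical request id REQ123):
#   <save_dir>/REQ123/REQ123.json                  # user feedback, when feedback_result is present
#   <save_dir>/REQ123/temp_imei_REQ123_0.json      # one file per imei page
#   <save_dir>/REQ123/temp_invoice_REQ123_0.json   # one file per invoice page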


def gen_req_to_red_dict(csv_path: str):
    df = pd.read_csv(csv_path)
    df = df.loc[:, ["requestId", "redemptionNumber"]]
    req_to_red = {row["requestId"]: row["redemptionNumber"] for _, row in df.iterrows()}
    return req_to_red


def gen_req_to_red_dict_2(csv_path: str):
    df = pd.read_csv(csv_path)
    df = df.loc[:, ["request_id", "redemption_id"]]
    req_to_red = {row["request_id"]: row["redemption_id"] for _, row in df.iterrows()}
    return req_to_red
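# Both helpers return a plain requestId -> redemption-number mapping, e.g. (hypothetical values):
# {"REQ123": "RED456", "REQ124": "RED457"}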


def init_csv(
    gt_dir: str,
    pred_dir: str,
    req_to_red: dict,
):
    list_request_id = os.listdir(gt_dir)
    total = []
    for request_id in list_request_id:
        gt_path = os.path.join(gt_dir, request_id, request_id+".json")
        if not os.path.exists(gt_path):
            logger.warning(f"Skipped request id {os.path.basename(os.path.dirname(gt_path))}")
            continue
        gt_data = read_json(gt_path)
        json_file_list = glob.glob(os.path.join(pred_dir, request_id, "temp_*.json"))
        json_file_list = sorted(json_file_list, key=lambda x: int(x.split(".json")[0].split('_')[-1]))
        n_imei, n_invoice = 0, 0
        # if len(json_file_list) > 3:
        #     continue

        for json_file in json_file_list:
            pred_data = read_json(json_file)
            if "imei" in json_file:
                pred_value = pred_data['imei_number']
                gt_value = gt_data['imei_number'][n_imei]
                n_imei += 1
                score = eval_ocr_metric(
                    [post_processing_str("imei_number", pred_value, is_gt=False)],
                    [post_processing_str("imei_number", gt_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']

                total.append({
                    "requestId": request_id,
                    "redemptionNumber": req_to_red[request_id],
                    "userSubmitResults": gt_value,
                    "OCRResults": pred_value,
                    "revisedResults_by_SDSRV": "",
                    "accuracy": score,
                    "processingTime (by request)": pred_data['processing_time'],
                    "class_name": "imei_number",
                    "file_path": json_file
                })
            elif "invoice" in json_file:
                for class_name in ["retailername", "purchase_date"]:
                    pred_value = pred_data[class_name]
                    gt_value = gt_data[class_name]
                    if isinstance(gt_value, list):
                        gt_value = gt_value[0]
                    n_invoice += 1

                    if not isinstance(pred_value, list):
                        pred_value = [pred_value]
                    score = 0
                    for _pred_value in pred_value:
                        score1 = eval_ocr_metric(
                            [post_processing_str(class_name, _pred_value, is_gt=False)],
                            [post_processing_str(class_name, gt_value, is_gt=True)],
                            metric=["one_minus_ned"]
                        )['1-N.E.D']
                        score = max(score, score1)

                    total.append({
                        "requestId": request_id,
                        "redemptionNumber": req_to_red[request_id],
                        "userSubmitResults": gt_value,
                        "OCRResults": pred_value[0] if class_name == "retailername" else pred_value,
                        "revisedResults_by_SDSRV": "",
                        "accuracy": score,
                        "processingTime (by request)": pred_data['processing_time'],
                        "class_name": class_name,
                        "file_path": json_file
                    })
    return total
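# Each element of `total` scores one field of one predicted file, e.g. (hypothetical values):
# {"requestId": "REQ123", "redemptionNumber": "RED456", "userSubmitResults": "356789104312345",
#  "OCRResults": "356789104312345", "accuracy": 1.0, "class_name": "imei_number", ...}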


def export_report(
    init_csv: str,
):
    df = pd.read_csv(init_csv)
    for index, request in df.iterrows():
        file_path = request['file_path']
        class_name = request['class_name']
        pred_value = request['OCRResults']
        revised_value = read_json(file_path)[class_name]

        if class_name == "purchase_date":
            pred_value = ast.literal_eval(pred_value)
            if isinstance(revised_value, list):
                if len(revised_value) > 0:
                    revised_value = revised_value[0]
                else:
                    revised_value = None
            if len(pred_value) == 0:
                pred_value = [None]
            score = 0
            for _pred_value in pred_value:
                score1 = eval_ocr_metric(
                    [post_processing_str(class_name, _pred_value, is_gt=False)],
                    [post_processing_str(class_name, revised_value, is_gt=True)],
                    metric=["one_minus_ned"]
                )['1-N.E.D']
                score = max(score, score1)
        else:
            score = eval_ocr_metric(
                [post_processing_str(class_name, pred_value, is_gt=False)],
                [post_processing_str(class_name, revised_value, is_gt=True)],
                metric=["one_minus_ned"]
            )['1-N.E.D']

        df.at[index, "revisedResults_by_SDSRV"] = revised_value
        df.at[index, "accuracy"] = score
    return df


def pick_sample_to_revise(
    ocr_accuracy: list,
    gt_dir: str,
    save_dir: str
):
    empty_err_path = os.path.join(save_dir, "empty_results")
    other_err_path = os.path.join(save_dir, "diff_results")
    os.makedirs(empty_err_path, exist_ok=True)
    os.makedirs(other_err_path, exist_ok=True)

    for request in ocr_accuracy:
        score = request['accuracy']
        json_path = request['file_path']
        request_id = request['requestId']
        img_path_folder = os.path.join(gt_dir, Path(json_path).parts[-2], Path(json_path).parts[-1])
        img_path = [ff for ff in glob.glob(img_path_folder.replace(".json", ".*")) if ".json" not in ff]
        if len(img_path) == 0:
            logger.warning(f"Skipped request id {request_id}")
            continue
        img_path = img_path[0]
        # img_path = [ff for ff in glob.glob(json_path.replace(".json", ".*"))][0]

        if score == 0:
            save_path = os.path.join(empty_err_path, request_id)
        elif score < 1:
            save_path = os.path.join(other_err_path, request_id)
        else:
            continue
        os.makedirs(save_path, exist_ok=True)
        shutil.copy(img_path, save_path)
        shutil.copy(json_path, save_path)


def merge_revised_sample(
    revised_path_list: list,
    save_dir: str
):
    if not isinstance(revised_path_list, list):
        revised_path_list = [revised_path_list]
    for revised_path in revised_path_list:
        list_request = [os.path.basename(ff) for ff in os.listdir(revised_path)]
        for request in list_request:
            file_list = glob.glob(os.path.join(revised_path, request, "*.json*"))
            for file_path in file_list:
                # shutil.copyfile(file_path, os.path.join(save_path, request))
                os.system(f"sudo cp {file_path} {os.path.join(save_dir, request)}")
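# Illustrative call (paths taken from the commented-out Step 5 below):
# merge_revised_sample([f"{revised_path}/empty_results", f"{revised_path}/diff_results"], save_dir=save_path)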


def calculate_average_by_column(df, column_name):
    df = df.groupby(by=["requestId"])
    time_list = []
    for req, sub_df in df:
        if len(sub_df) > 0:
            time_list.append(sub_df.iloc[0][column_name])
    if len(time_list) > 0:
        return sum(time_list)/len(time_list)
    return 0
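# Illustrative usage (hypothetical): calculate_average_by_column(report, "processingTime (by request)")
# averages the first row of each requestId group, so every request contributes a single sample.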


if __name__ == "__main__":
    save_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan"
    save_csv = "logs/eval_20240115"
    csv_path = "/mnt/hdd4T/TannedCung/OCR/Data/SBT_for_acc/15Jan.csv"
    csv_path_end_user = "logs/eval_20240115/OCR_15Jan2024.csv"

    # Step 1: Convert a csv file to get user-submitted results for each request
    logger.info("Starting to convert csv from customer to json")
    os.system(f"sudo chmod -R 777 {save_path}")
    convert_groundtruth_from_csv(csv_path=csv_path_end_user, save_dir=save_path)
    logger.info("Converted")

    # # Step 2: Convert a csv file to get predicted OCR results for each image
    logger.info("Starting to convert csv from SDSV to json")
    convert_predict_from_csv(csv_path=csv_path, save_dir=save_path)
    logger.info("Converted")

    # # Step 3: Generate the initial csv file and score OCR results against submitted results
    logger.info("Starting to generate csv to get performance")
    gt_path = save_path
    pred_path = save_path
    req_to_red_dict = gen_req_to_red_dict(csv_path_end_user)
    init_data = init_csv(gt_dir=gt_path, pred_dir=pred_path, req_to_red=req_to_red_dict)
    pd.DataFrame(init_data).to_csv(os.path.join(save_csv, "init1.csv"), index=False)
    logger.info("Done")

    # # Step 4: Split requests whose accuracy is less than 1 for revision
    # logger.info("Starting to split data to review")
    # revised_path = os.path.join(save_csv, "revised")
    # # shutil.rmtree(revised_path)
    # pick_sample_to_revise(ocr_accuracy=init_data, gt_dir=save_path, save_dir=revised_path)
    # logger.info("Done")

    # # Step 5: Merge revised results into the ground-truth folder
    # logger.info("Merging revised data to ground truth folder")
    # revised_path = os.path.join(save_csv, "revised")
    # revised_path = [f'{revised_path}/empty_results', f'{revised_path}/diff_results']
    # merge_revised_sample(revised_path_list=revised_path, save_dir=save_path)
    # print("Done")

    # # Step 6: Calculate OCR result between ocr results and revised results
    # logger.info("Exporting OCR report")
    # init_csv_path = os.path.join(save_csv, "init1.csv")
    # report = export_report(init_csv=init_csv_path)
    # error_path = os.path.join(save_csv, "errors")
    # pick_sample_to_revise(ocr_accuracy=report[report.accuracy < 0.75].to_dict('records'), gt_dir=save_path, save_dir=error_path)

    # n_total_images = len(report)
    # n_bad_images = len(report[report.accuracy < 0.75])
    # average_acc = report[report.accuracy >= 0.75]['accuracy'].mean()
    # print("Total requests:", len(report['requestId'].unique()))
    # print("Total images:", n_total_images)
    # print("No. imei images:", len(report[report.class_name == "imei_number"]))
    # print("No. invoice images:", len(report[report.class_name == "retailername"]))
    # print("No. bad quality images:", n_bad_images)
    # print("No. valid images:", n_total_images - n_bad_images)
    # print("Percentage of bad quality images:", 100*n_bad_images/n_total_images)
    # print("Average accuracy:", 100*average_acc)

    # last_row = n_total_images
    # report.at[last_row, "requestId"] = "Total requests:"
    # report.at[last_row, "redemptionNumber"] = len(report['requestId'].unique())
    # report.at[last_row+1, "requestId"] = "Total images:"
    # report.at[last_row+1, "redemptionNumber"] = n_total_images
    # report.at[last_row+2, "requestId"] = "No. imei images:"
    # report.at[last_row+2, "redemptionNumber"] = len(report[report.class_name == "imei_number"])
    # report.at[last_row+3, "requestId"] = "No. invoice images:"
    # report.at[last_row+3, "redemptionNumber"] = len(report[report.class_name == "retailername"])
    # report.at[last_row+4, "requestId"] = "No. bad quality images:"
    # report.at[last_row+4, "redemptionNumber"] = n_bad_images
    # report.at[last_row+5, "requestId"] = "No. valid images:"
    # report.at[last_row+5, "redemptionNumber"] = n_total_images - n_bad_images
    # report.at[last_row+6, "requestId"] = "Percentage of bad quality images:"
    # report.at[last_row+6, "redemptionNumber"] = 100*n_bad_images/n_total_images
    # report.at[last_row+7, "requestId"] = "Average accuracy:"
    # report.at[last_row+7, "redemptionNumber"] = 100*average_acc
    # report.drop(columns=["file_path", "class_name"]).to_csv(os.path.join(save_csv, f"SBT_report_{time.strftime('%Y%m%d')}.csv"), index=False)
    # logger.info("Done")