sbt-idp/cope2n-api/fwd_api/utils/ocr_utils/ocr_metrics.py

386 lines
13 KiB
Python
Raw Normal View History

2024-01-31 03:00:18 +00:00
import re
from pathlib import Path
from difflib import SequenceMatcher
from terminaltables import AsciiTable
from rapidfuzz.distance import Levenshtein
from .wiki_diff import inline_diff
def is_type_list(x, type):
    """Return True when *x* is a ``list`` whose elements are all instances of *type*.

    An empty list qualifies (vacuous truth). Any non-``list`` value, including
    tuples and other sequences, yields False.

    Note: the parameter name ``type`` shadows the builtin inside this function;
    it is kept as-is so keyword callers keep working.
    """
    return isinstance(x, list) and all(isinstance(element, type) for element in x)
def cal_true_positive_char(pred, gt):
    """Count the characters of *gt* that are matched verbatim in *pred*.

    The two strings are aligned with :class:`difflib.SequenceMatcher` (default
    settings). The lengths of all ``"equal"`` opcode spans are summed, measured
    on the ground-truth side.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        int: Number of true-positive characters.
    """
    opcodes = SequenceMatcher(None, pred, gt).get_opcodes()
    return sum(
        gt_end - gt_start
        for tag, _pred_start, _pred_end, gt_start, gt_end in opcodes
        if tag == "equal"
    )
def post_processing(text):
    """Normalise *text* for lenient comparison.

    The steps are:

    1. Replace every character outside the allow-list with a space. The
       allow-list is the Latin letters a-z/A-Z, the Vietnamese letters with
       diacritics in both cases, the ASCII digits 0-9 and the plain space.
    2. Collapse runs of two or more whitespace characters into one space.
    3. Strip leading and trailing whitespace.

    Note: this function does not change letter case. Callers that want a
    case-insensitive comparison lowercase the text before calling it.
    """
    allowed_only = re.sub(
        r"[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789 ]",
        " ",
        text,
    )
    collapsed = re.sub(r"\s\s+", " ", allowed_only)
    return collapsed.strip()
def count_matches(pred_texts, gt_texts, use_ignore=True):
    """Count the various match numbers used for metric calculation.

    Predictions and ground truths are paired positionally with ``zip``; extra
    items in the longer list are not compared.

    Args:
        pred_texts (list[str]): Predicted text strings.
        gt_texts (list[str]): Ground truth text strings.
        use_ignore (bool): When True, the lowercased texts are also normalised
            with :func:`post_processing` before the "ignore case & symbol"
            comparisons, the character statistics and the edit distances.
            When False, only lowercasing is applied.

    Returns:
        dict[str, int | float]: Match numbers used for metric calculation:
        ``gt_char_num``, ``pred_char_num``, ``true_positive_char_num``,
        ``gt_word_num`` (number of compared pairs), ``match_word_num`` and
        ``match_kie`` (exact matches), ``match_word_ignore_case`` (matches
        after lowercasing), and ``match_word_ignore_case_symbol`` and
        ``match_kie_ignore_case`` (matches after lowercasing and optional
        normalisation).

        It also holds ``ned`` and ``ned_word``: the summed character-level
        and word-level normalised edit distances divided by
        ``max(1, len(gt_texts))``.
    """
    match_res = {
        "gt_char_num": 0,
        "pred_char_num": 0,
        "true_positive_char_num": 0,
        "gt_word_num": 0,
        "match_word_num": 0,
        "match_word_ignore_case": 0,
        "match_word_ignore_case_symbol": 0,
        "match_kie": 0,
        "match_kie_ignore_case": 0,
    }
    norm_ed_sum = 0.0
    norm_ed_word_sum = 0.0
    # Each distinct word gets a unique small integer, so word sequences can be
    # compared with an edit distance. Only equality between ids matters, and a
    # dict gives O(1) lookup per word.
    word_ids = {}

    for pred_text, gt_text in zip(pred_texts, gt_texts):
        # Exact match (case- and symbol-sensitive).
        if gt_text == pred_text:
            match_res["match_word_num"] += 1
            match_res["match_kie"] += 1

        gt_text_lower = str(gt_text).lower()
        pred_text_lower = str(pred_text).lower()
        if gt_text_lower == pred_text_lower:
            match_res["match_word_ignore_case"] += 1

        if use_ignore:
            gt_text_lower_ignore = post_processing(gt_text_lower)
            pred_text_lower_ignore = post_processing(pred_text_lower)
        else:
            gt_text_lower_ignore = gt_text_lower
            pred_text_lower_ignore = pred_text_lower

        # Case- and symbol-insensitive match. This feeds both the word-level
        # and the KIE "ignore case & symbol" accuracies.
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res["match_word_ignore_case_symbol"] += 1
            match_res["match_kie_ignore_case"] += 1

        match_res["gt_word_num"] += 1

        # Character-level normalised edit distance.
        norm_ed_sum += Levenshtein.normalized_distance(
            pred_text_lower_ignore, gt_text_lower_ignore
        )

        # Word-level normalised edit distance over space-separated tokens.
        pred_word_ids = [
            word_ids.setdefault(word, len(word_ids))
            for word in pred_text_lower_ignore.split(" ")
        ]
        gt_word_ids = [
            word_ids.setdefault(word, len(word_ids))
            for word in gt_text_lower_ignore.split(" ")
        ]
        norm_ed_word_sum += Levenshtein.normalized_distance(pred_word_ids, gt_word_ids)

        # Counts used for character-level recall & precision.
        match_res["gt_char_num"] += len(gt_text_lower_ignore)
        match_res["pred_char_num"] += len(pred_text_lower_ignore)
        match_res["true_positive_char_num"] += cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore
        )

    # max(1, ...) avoids division by zero for empty input.
    denominator = max(1, len(gt_texts))
    match_res["ned"] = norm_ed_sum / denominator
    match_res["ned_word"] = norm_ed_word_sum / denominator
    return match_res
def eval_ocr_metric(pred_texts, gt_texts, metric="acc"):
    """Evaluate text recognition performance (word accuracy and 1 - N.E.D.).

    See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth. Must have the same
            length as ``pred_texts``.
        metric (str | list[str]): Metric(s) to be evaluated. Options are:

            - 'word_acc': exact-match accuracy.
            - 'word_acc_ignore_case': accuracy after lowercasing.
            - 'word_acc_ignore_case_symbol': accuracy after lowercasing and
              symbol/whitespace normalisation. This is the default metric for
              academic evaluation.
            - 'char_recall': character-level recall, ignoring letter case and
              symbols.
            - 'char_precision': character-level precision, ignoring letter
              case and symbols.
            - 'one_minus_ned': 1 - mean character-level normalised edit
              distance, reported under the key '1-N.E.D'.
            - 'one_minus_ned_word': 1 - mean word-level normalised edit
              distance, reported under the key '1-N.E.D_word'.
            - 'line_acc': exact-match accuracy per text line (same counts as
              'word_acc').
            - 'line_acc_ignore_case_symbol': line accuracy after lowercasing
              and symbol/whitespace normalisation.

            If ``metric`` is ``'acc'`` or ``['acc']``, the first six metrics
            above are reported. Unrecognised metric names are silently
            ignored.

    Returns:
        dict[str, float]: Requested results, rounded to 4 decimal places. Keys
        are listed in the order char_recall, char_precision, word_acc,
        word_acc_ignore_case, word_acc_ignore_case_symbol, 1-N.E.D,
        1-N.E.D_word, line_acc_ignore_case_symbol, line_acc.

    Raises:
        AssertionError: If the inputs are not lists of equal length, or if
            ``metric`` is neither a string nor a list of strings.
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)
    assert isinstance(metric, str) or is_type_list(metric, str)
    if metric == "acc" or metric == ["acc"]:
        metric = [
            "word_acc",
            "word_acc_ignore_case",
            "word_acc_ignore_case_symbol",
            "char_recall",
            "char_precision",
            "one_minus_ned",
        ]
    requested = {metric} if isinstance(metric, str) else set(metric)

    match_res = count_matches(pred_texts, gt_texts)
    eps = 1e-8  # keeps every ratio finite when its denominator count is 0

    def ratio(numerator_key, denominator_key):
        # Smoothed ratio of two counters from ``match_res``.
        return 1.0 * match_res[numerator_key] / (eps + match_res[denominator_key])

    # (requested metric name, key in the returned dict, value), in report order.
    candidates = (
        ("char_recall", "char_recall",
         ratio("true_positive_char_num", "gt_char_num")),
        ("char_precision", "char_precision",
         ratio("true_positive_char_num", "pred_char_num")),
        ("word_acc", "word_acc",
         ratio("match_word_num", "gt_word_num")),
        ("word_acc_ignore_case", "word_acc_ignore_case",
         ratio("match_word_ignore_case", "gt_word_num")),
        ("word_acc_ignore_case_symbol", "word_acc_ignore_case_symbol",
         ratio("match_word_ignore_case_symbol", "gt_word_num")),
        ("one_minus_ned", "1-N.E.D", 1.0 - match_res["ned"]),
        ("one_minus_ned_word", "1-N.E.D_word", 1.0 - match_res["ned_word"]),
        ("line_acc_ignore_case_symbol", "line_acc_ignore_case_symbol",
         ratio("match_kie_ignore_case", "gt_word_num")),
        ("line_acc", "line_acc",
         ratio("match_kie", "gt_word_num")),
    )
    # Rounded to 4 decimal places via string formatting.
    return {
        out_key: float("{:.4f}".format(value))
        for name, out_key, value in candidates
        if name in requested
    }
_KIE_AVERAGED_KEYS = ("1-ned", "1-ned-word", "line_acc", "line_acc_ignore_case_symbol")


def _kie_weighted_average(results, labels):
    """Average each per-label KIE metric, weighted by that label's sample count."""
    total_samples = sum(results[label]["samples"] for label in labels)
    averaged = {}
    for key in _KIE_AVERAGED_KEYS:
        weighted_sum = sum(
            results[label][key] * results[label]["samples"] for label in labels
        )
        # With no samples at all the average is undefined. Report 0.0 rather
        # than raising ZeroDivisionError; "samples" == 0 makes this visible.
        averaged[key] = round(weighted_sum / total_samples, 4) if total_samples else 0.0
    averaged["samples"] = total_samples
    return averaged


def _print_kie_table(results):
    """Print every row of *results* (labels and "avg_all") as an ASCII table."""
    table_data = [
        [
            "class_name",
            "1-NED",
            "1-N.E.D_word",
            "line-acc",
            "line_acc_ignore_case_symbol",
            "#samples",
        ]
    ]
    for class_name, row in results.items():
        table_data.append(
            [
                class_name,
                row["1-ned"],
                row["1-ned-word"],
                row["line_acc"],
                row["line_acc_ignore_case_symbol"],
                row["samples"],
            ]
        )
    print(AsciiTable(table_data).table)


def eval_kie(preds_e2e: dict[str, dict[str, str]], gt_e2e: dict[str, dict[str, str]], labels, skip_labels=None):
    """Evaluate key-information-extraction predictions field by field.

    Every image in ``gt_e2e`` is visited. Each of its ground-truth fields that
    is not in ``skip_labels`` is compared with the predicted value. The
    predicted value is ``""`` when the image or the field is missing from
    ``preds_e2e``.

    Per-label metrics are computed with :func:`eval_ocr_metric`. A
    sample-weighted average of them is added under ``"avg_all"``, and a
    summary table is printed to stdout.

    Args:
        preds_e2e: Predictions laid out as ``{img_id: {class_name: text}}``.
        gt_e2e: Ground truth with the same layout. It drives the iteration.
        labels: Class names to report. Every non-skipped class name found in
            ``gt_e2e`` must appear here, otherwise a ``KeyError`` is raised.
        skip_labels: Class names to ignore (any container supporting ``in``).
            ``None`` (the default) skips nothing.

    Returns:
        tuple[dict, dict]: ``(results, fail_cases)``.

        ``results`` maps each label, plus ``"avg_all"``, to a dict with keys
        ``"1-ned"``, ``"1-ned-word"``, ``"line_acc"``,
        ``"line_acc_ignore_case_symbol"`` and ``"samples"``.

        ``fail_cases`` maps every ground-truth image id to
        ``{class_name: {"pred", "gt", "diff", "ned", "score"}}``. It contains
        only the fields whose ``str()`` forms differ.
    """
    skip = () if skip_labels is None else skip_labels
    pred_texts_dict = {label: [] for label in labels}
    gt_texts_dict = {label: [] for label in labels}
    fail_cases = {}

    for img_id, gt_items in gt_e2e.items():
        fail_cases[img_id] = {}
        # A missing image behaves as if every field were predicted as "".
        pred_items = preds_e2e.get(img_id, {})
        for class_name, text_gt in gt_items.items():
            if class_name in skip:
                continue
            text_pred = pred_items.get(class_name, "")
            if str(text_pred) != str(text_gt):
                diff = inline_diff(text_pred, text_gt)
                fail_cases[img_id][class_name] = {
                    "pred": text_pred,
                    "gt": text_gt,
                    "diff": diff["res_text"],
                    "ned": diff["ned"],
                    "score": eval_ocr_metric(
                        [text_pred], [text_gt], metric=["one_minus_ned"]
                    )["1-N.E.D"],
                }
            pred_texts_dict[class_name].append(text_pred)
            gt_texts_dict[class_name].append(text_gt)

    results = {}
    for class_name in labels:
        pred_texts = pred_texts_dict[class_name]
        result = eval_ocr_metric(
            pred_texts,
            gt_texts_dict[class_name],
            metric=[
                "one_minus_ned",
                "line_acc_ignore_case_symbol",
                "line_acc",
                "one_minus_ned_word",
            ],
        )
        results[class_name] = {
            "1-ned": result["1-N.E.D"],
            "1-ned-word": result["1-N.E.D_word"],
            "line_acc": result["line_acc"],
            "line_acc_ignore_case_symbol": result["line_acc_ignore_case_symbol"],
            "samples": len(pred_texts),
        }

    results["avg_all"] = _kie_weighted_average(results, labels)
    _print_kie_table(results)
    return results, fail_cases