import re
from pathlib import Path
from difflib import SequenceMatcher

from terminaltables import AsciiTable
from rapidfuzz.distance import Levenshtein

from .wiki_diff import inline_diff
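
# Note on the distance used throughout this module: as we understand
# rapidfuzz's API (worth double-checking against the installed version),
# Levenshtein.normalized_distance(a, b) divides the edit distance by the
# maximum possible distance, which for the default unit weights is
# max(len(a), len(b)); e.g. normalized_distance("world", "word") == 0.2.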


def is_type_list(x, typ):
    """Return True if ``x`` is a list whose items are all instances of ``typ``."""
    if not isinstance(x, list):
        return False
    return all(isinstance(item, typ) for item in x)


def cal_true_positive_char(pred, gt):
    """Calculate the number of correctly predicted characters.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        int: The true positive character count.
    """
    all_opt = SequenceMatcher(None, pred, gt)
    true_positive_char_num = 0
    # Sum the lengths of all blocks the two strings have in common.
    for opt, _, _, s2, e2 in all_opt.get_opcodes():
        if opt == "equal":
            true_positive_char_num += e2 - s2
    return true_positive_char_num
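
# A minimal sketch of the matching above (hypothetical strings, worked out by
# hand): SequenceMatcher aligns "hel" and "o" as "equal" blocks, so
#   >>> cal_true_positive_char("helo", "hello")
#   4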


def post_processing(text):
    """Remove special characters and collapse extra spaces.

    Every character outside the Vietnamese alphabet, digits, and spaces is
    replaced with a space, then repeated whitespace is collapsed and the
    result is stripped. Lower-casing is expected to be done by the caller.
    """
    text = re.sub(
        r"[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789 ]",
        " ",
        text,
    )
    text = re.sub(r"\s\s+", " ", text)
    text = text.strip()

    return text
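
# A minimal sketch of the normalization (hypothetical Vietnamese input):
#   >>> post_processing("tổng cộng:  1.000.000đ")
#   'tổng cộng 1 000 000đ'
# Punctuation becomes spaces, whitespace runs collapse to one space, and the
# result is stripped.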


def count_matches(pred_texts, gt_texts, use_ignore=True):
    """Count the various match numbers used for metric calculation.

    Args:
        pred_texts (list[str]): Predicted text strings.
        gt_texts (list[str]): Ground truth text strings.
        use_ignore (bool): Whether to apply :func:`post_processing` before
            matching. Defaults to True.

    Returns:
        dict[str, int | float]: Match counts used for metric calculation.
    """
    match_res = {
        "gt_char_num": 0,
        "pred_char_num": 0,
        "true_positive_char_num": 0,
        "gt_word_num": 0,
        "match_word_num": 0,
        "match_word_ignore_case": 0,
        "match_word_ignore_case_symbol": 0,
        "match_kie": 0,
        "match_kie_ignore_case": 0,
    }
    norm_ed_sum = 0.0

    gt_texts_for_ned_word = []
    pred_texts_for_ned_word = []
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res["match_word_num"] += 1
            match_res["match_kie"] += 1
        gt_text_lower = str(gt_text).lower()
        pred_text_lower = str(pred_text).lower()

        if gt_text_lower == pred_text_lower:
            match_res["match_word_ignore_case"] += 1

        if use_ignore:
            gt_text_lower_ignore = post_processing(gt_text_lower)
            pred_text_lower_ignore = post_processing(pred_text_lower)
        else:
            gt_text_lower_ignore = gt_text_lower
            pred_text_lower_ignore = pred_text_lower

        if gt_text_lower_ignore == pred_text_lower_ignore:
            # This counter was initialized but never incremented in the
            # original code, which made word_acc_ignore_case_symbol always 0.
            match_res["match_word_ignore_case_symbol"] += 1
            match_res["match_kie_ignore_case"] += 1

        gt_texts_for_ned_word.append(gt_text_lower_ignore.split(" "))
        pred_texts_for_ned_word.append(pred_text_lower_ignore.split(" "))

        match_res["gt_word_num"] += 1

        norm_ed = Levenshtein.normalized_distance(
            pred_text_lower_ignore, gt_text_lower_ignore
        )
        norm_ed_sum += norm_ed

        # Numbers used for char-level recall & precision.
        match_res["gt_char_num"] += len(gt_text_lower_ignore)
        match_res["pred_char_num"] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore
        )
        match_res["true_positive_char_num"] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res["ned"] = normalized_edit_distance

    # Word-level NED: map each word to an integer id so that the edit
    # distance is computed over word sequences instead of characters.
    norm_ed_word_sum = 0.0
    unique_words = list(
        set(
            [x for line in pred_texts_for_ned_word for x in line]
            + [x for line in gt_texts_for_ned_word for x in line]
        )
    )
    word_to_id = {word: idx for idx, word in enumerate(unique_words)}
    preds = [
        [word_to_id[w] for w in pred_text_for_ned_word]
        for pred_text_for_ned_word in pred_texts_for_ned_word
    ]
    truths = [
        [word_to_id[w] for w in gt_text_for_ned_word]
        for gt_text_for_ned_word in gt_texts_for_ned_word
    ]
    for pred_text, gt_text in zip(preds, truths):
        norm_ed_word = Levenshtein.normalized_distance(pred_text, gt_text)
        norm_ed_word_sum += norm_ed_word

    normalized_edit_distance_word = norm_ed_word_sum / max(1, len(gt_texts))
    match_res["ned_word"] = normalized_edit_distance_word

    return match_res
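
# A minimal sketch of the counters (hypothetical two-sample batch, worked out
# by hand): "Hello!" matches "hello" once case and symbols are ignored, while
# "world" vs "word" differs by one edit out of five characters.
#   >>> res = count_matches(["Hello!", "world"], ["hello", "word"])
#   >>> res["match_word_num"], res["match_word_ignore_case_symbol"]
#   (0, 1)
#   >>> res["ned"]  # (0.0 + 0.2) / 2
#   0.1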


def eval_ocr_metric(pred_texts, gt_texts, metric="acc"):
    """Evaluate text recognition performance with word accuracy and
    1 - N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth.
        metric (str | list[str]): Metric(s) to be evaluated. Options are:

            - 'word_acc': Accuracy at word level.
            - 'word_acc_ignore_case': Accuracy at word level, ignoring letter
              case.
            - 'word_acc_ignore_case_symbol': Accuracy at word level, ignoring
              letter case and symbols. (Default metric for academic evaluation)
            - 'char_recall': Recall at character level, ignoring letter case
              and symbols.
            - 'char_precision': Precision at character level, ignoring letter
              case and symbols.
            - 'one_minus_ned': 1 - normalized_edit_distance.

            In particular, if ``metric == 'acc'``, all of the metrics above
            will be reported.

    Returns:
        dict[str, float]: Result dict for text recognition; keys may include
        'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
        'char_recall', 'char_precision' and '1-N.E.D'.
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)

    assert isinstance(metric, str) or is_type_list(metric, str)
    if metric == "acc" or metric == ["acc"]:
        metric = [
            "word_acc",
            "word_acc_ignore_case",
            "word_acc_ignore_case_symbol",
            "char_recall",
            "char_precision",
            "one_minus_ned",
        ]
    metric = {metric} if isinstance(metric, str) else set(metric)

    match_res = count_matches(pred_texts, gt_texts)
    eps = 1e-8
    eval_res = {}

    if "char_recall" in metric:
        char_recall = (
            1.0 * match_res["true_positive_char_num"] / (eps + match_res["gt_char_num"])
        )
        eval_res["char_recall"] = char_recall

    if "char_precision" in metric:
        char_precision = (
            1.0
            * match_res["true_positive_char_num"]
            / (eps + match_res["pred_char_num"])
        )
        eval_res["char_precision"] = char_precision

    if "word_acc" in metric:
        word_acc = 1.0 * match_res["match_word_num"] / (eps + match_res["gt_word_num"])
        eval_res["word_acc"] = word_acc

    if "word_acc_ignore_case" in metric:
        word_acc_ignore_case = (
            1.0 * match_res["match_word_ignore_case"] / (eps + match_res["gt_word_num"])
        )
        eval_res["word_acc_ignore_case"] = word_acc_ignore_case

    if "word_acc_ignore_case_symbol" in metric:
        word_acc_ignore_case_symbol = (
            1.0
            * match_res["match_word_ignore_case_symbol"]
            / (eps + match_res["gt_word_num"])
        )
        eval_res["word_acc_ignore_case_symbol"] = word_acc_ignore_case_symbol

    if "one_minus_ned" in metric:
        eval_res["1-N.E.D"] = 1.0 - match_res["ned"]

    if "one_minus_ned_word" in metric:
        eval_res["1-N.E.D_word"] = 1.0 - match_res["ned_word"]

    if "line_acc_ignore_case_symbol" in metric:
        line_acc_ignore_case_symbol = (
            1.0 * match_res["match_kie_ignore_case"] / (eps + match_res["gt_word_num"])
        )
        eval_res["line_acc_ignore_case_symbol"] = line_acc_ignore_case_symbol

    if "line_acc" in metric:
        line_acc = 1.0 * match_res["match_kie"] / (eps + match_res["gt_word_num"])
        eval_res["line_acc"] = line_acc

    for key, value in eval_res.items():
        eval_res[key] = float("{:.4f}".format(value))

    return eval_res
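
# A minimal usage sketch (hypothetical inputs; numbers worked out by hand and
# rounded by the function to four decimals):
#   >>> eval_ocr_metric(["hello", "world"], ["hello", "word"])
#   {'char_recall': 1.0, 'char_precision': 0.9, 'word_acc': 0.5,
#    'word_acc_ignore_case': 0.5, 'word_acc_ignore_case_symbol': 0.5,
#    '1-N.E.D': 0.9}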


def eval_kie(
    preds_e2e: dict[str, dict[str, str]],
    gt_e2e: dict[str, dict[str, str]],
    labels,
    skip_labels=(),
):
    """Evaluate end-to-end KIE results per field label.

    Args:
        preds_e2e: Mapping of image id -> {field label: predicted text}.
        gt_e2e: Mapping of image id -> {field label: ground truth text}.
        labels (list[str]): Field labels to evaluate.
        skip_labels: Field labels to exclude from evaluation.

    Returns:
        tuple: ``(results, fail_cases)``, where ``results`` maps each label
        (plus ``"avg_all"``) to its metrics and ``fail_cases`` records the
        mismatched predictions per image.
    """
    results = {}
    pred_texts_dict = {label: [] for label in labels}
    gt_texts_dict = {label: [] for label in labels}
    fail_cases = {}
    for img_id in gt_e2e:
        fail_cases[img_id] = {}
        pred_items = preds_e2e.get(img_id, {})
        gt_items = gt_e2e[img_id]

        for class_name, text_gt in gt_items.items():
            if class_name in skip_labels:
                continue
            text_pred = pred_items.get(class_name, "")

            if str(text_pred) != str(text_gt):
                diff = inline_diff(text_pred, text_gt)
                fail_cases[img_id][class_name] = {
                    "pred": text_pred,
                    "gt": text_gt,
                    "diff": diff["res_text"],
                    "ned": diff["ned"],
                    "score": eval_ocr_metric(
                        [text_pred], [text_gt], metric=["one_minus_ned"]
                    )["1-N.E.D"],
                }

            pred_texts_dict[class_name].append(text_pred)
            gt_texts_dict[class_name].append(text_gt)

    for class_name in labels:
        pred_texts = pred_texts_dict[class_name]
        gt_texts = gt_texts_dict[class_name]
        result = eval_ocr_metric(
            pred_texts,
            gt_texts,
            metric=[
                "one_minus_ned",
                "line_acc_ignore_case_symbol",
                "line_acc",
                "one_minus_ned_word",
            ],
        )
        results[class_name] = {
            "1-ned": result["1-N.E.D"],
            "1-ned-word": result["1-N.E.D_word"],
            "line_acc": result["line_acc"],
            "line_acc_ignore_case_symbol": result["line_acc_ignore_case_symbol"],
            "samples": len(pred_texts),
        }

    # Average results over all labels, weighted by sample count.
    sum_1_ned = sum(
        results[class_name]["1-ned"] * results[class_name]["samples"]
        for class_name in labels
    )
    sum_1_ned_word = sum(
        results[class_name]["1-ned-word"] * results[class_name]["samples"]
        for class_name in labels
    )

    sum_line_acc = sum(
        results[class_name]["line_acc"] * results[class_name]["samples"]
        for class_name in labels
    )
    sum_line_acc_ignore_case_symbol = sum(
        results[class_name]["line_acc_ignore_case_symbol"]
        * results[class_name]["samples"]
        for class_name in labels
    )

    total_samples = sum(results[class_name]["samples"] for class_name in labels)
    # Guard against division by zero when every label has zero samples.
    denom = max(1, total_samples)
    results["avg_all"] = {
        "1-ned": round(sum_1_ned / denom, 4),
        "1-ned-word": round(sum_1_ned_word / denom, 4),
        "line_acc": round(sum_line_acc / denom, 4),
        "line_acc_ignore_case_symbol": round(
            sum_line_acc_ignore_case_symbol / denom, 4
        ),
        "samples": total_samples,
    }

    table_data = [
        [
            "class_name",
            "1-NED",
            "1-NED-word",
            "line_acc",
            "line_acc_ignore_case_symbol",
            "#samples",
        ]
    ]
    for class_name in results:
        table_data.append(
            [
                class_name,
                results[class_name]["1-ned"],
                results[class_name]["1-ned-word"],
                results[class_name]["line_acc"],
                results[class_name]["line_acc_ignore_case_symbol"],
                results[class_name]["samples"],
            ]
        )

    table = AsciiTable(table_data)
    print(table.table)
    return results, fail_cases
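
# A minimal usage sketch (hypothetical image ids, labels, and texts):
#   preds = {"img_001": {"total": "1.000.000", "date": "01/01/2024"}}
#   gts = {"img_001": {"total": "1.000.000", "date": "01/01/2023"}}
#   results, fail_cases = eval_kie(preds, gts, labels=["total", "date"])
# ``results`` then holds per-label metrics plus the sample-weighted
# "avg_all" row, and ``fail_cases["img_001"]["date"]`` records the
# mismatched prediction together with its inline diff and 1-N.E.D score.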