# sbt-idp/cope2n-api/fwd_api/utils/ocr_utils/wiki_diff.py
# Snapshot: 2024-06-26 14:58:24 +07:00 — 205 lines, 7.6 KiB, Python.

# https://stackoverflow.com/questions/774316/python-difflib-highlighting-differences-inline
import difflib
import unidecode
import os
import glob
import pandas as pd
import logging
logger = logging.getLogger(__name__)
VOWELS = 'aeouiy' + 'AEOUIY'
CONSONANTS = 'bcdfghjklmnpqrstvxwz' + 'BCDFGHJKLMNPQRSTVXWZ'
# PREDICT_PATH = 'ocr/result'
# GROUNDTRUTH_PATH = '/mnt/hdd2T/AICR/Datasets/wiki/ground_truth'
PREDICT_PATH = 'ocr/result/cinamon'
GROUNDTRUTH_PATH = '/mnt/hdd2T/AICR/Datasets/Backup/1.Hand_writing/Lines/cinnamon_data'
# note that we also use different preprocess for cinamon data
# SAVE_PATH = 'wiki_diff'
SAVE_PATH = 'wiki_diff/cinamon'
RES_PATH = f'{SAVE_PATH}/result/'
WRONG_ACCENT_FILE = f'{SAVE_PATH}/wrong_accent.txt'
LOST_ACCENT_FILE = f'{SAVE_PATH}/lost_accent.txt'
TOTAL_WORD = 0
def write_accent_error(path, err):
# path should be wrong_accent_file or lost_accent_file
with open(path, 'a') as f:
f.write(err)
f.write('\n')
def update_ddata_specialchars(ddata_specialchars, correction_key, char_key):
if char_key in ddata_specialchars[correction_key]:
ddata_specialchars[correction_key][char_key] += 1
else:
ddata_specialchars[correction_key][char_key] = 1
def process_replace_tag(matcher, i1, i2, j1, j2, ddata, ddata_specialchars):
a_char = matcher.a[i1:i2]
b_char = matcher.b[j1:j2]
ddata['res_text'] += ' ### {' + a_char + ' -> ' + b_char + '} ### '
ddata['nwrongs'] += 1*len(b_char)
if len(a_char) == 1 and len(b_char) == 1: # single char case
if a_char.lower() == b_char.lower(): # wrong upper/lower case
ddata['UL_single'] += 1
update_ddata_specialchars(ddata_specialchars, 'UL', (a_char, b_char))
else:
ddata['nwrongs_single'] += 1
a_ori = unidecode.unidecode(a_char).lower()
b_ori = unidecode.unidecode(b_char).lower()
if a_ori in VOWELS and b_ori in VOWELS:
if a_ori == b_ori:
err = a_char + ' -> ' + b_char
if b_ori == b_char.lower(): # e.g. Ơ -> O
ddata['nlost_accent'] += 1
# write_accent_error(LOST_ACCENT_FILE, err)
else: # e.g Ơ -> Ớ
ddata['nwrong_accent'] += 1
# write_accent_error(WRONG_ACCENT_FILE, err)
else: # e.g Ă -> Â
ddata['nwrong_vowels'] += 1
else:
if a_ori in CONSONANTS and b_ori in CONSONANTS:
ddata['nwrong_consonants'] += 1
else:
ddata['nwrong_specialchars'] += 1
update_ddata_specialchars(ddata_specialchars, 'wrong', (a_char, b_char))
else:
if a_char.lower() == b_char.lower():
ddata['UL_multiple'] += 1
update_ddata_specialchars(ddata_specialchars, 'UL', (a_char, b_char))
else:
ddata['nwrongs_multiple'] += 1
if len(a_char) > 10 or len(b_char) > 10:
ddata['nlong_sequences'] += 1
# print(a_char)
def process_delete_tag(matcher, i1, i2, ddata, ddata_specialchars):
a_char = matcher.a[i1:i2]
ddata['res_text'] += ' ### {- ' + a_char + '} ### '
ddata['nadds'] += 1*len(a_char)
if len(a_char) == 1:
ddata['nadds_single'] += 1
if a_char.lower() in CONSONANTS + VOWELS:
ddata['nadds_chars'] += 1
else:
if a_char == ' ':
ddata['nadds_space'] += 1
else:
ddata['nadds_specialchars'] += 1
update_ddata_specialchars(ddata_specialchars, 'add', a_char)
else:
ddata['nadds_multiple'] += 1
if len(a_char) > 10:
ddata['nlong_sequences'] += 1
# print(a_char)
def process_insert_tag(matcher, j1, j2, ddata, ddata_specialchars):
b_char = matcher.b[j1:j2]
ddata['nlosts'] += 1*len(b_char)
ddata['res_text'] += ' ### {+ ' + b_char + '} ### '
if len(b_char) == 1:
ddata['nlosts_single'] += 1
if b_char.lower() in CONSONANTS + VOWELS:
ddata['nlosts_chars'] += 1
else:
if b_char == ' ':
ddata['nlosts_space'] += 1
else:
ddata['nlosts_specialchars'] += 1
update_ddata_specialchars(ddata_specialchars, 'lost', b_char)
else:
ddata['nlosts_multiple'] += 1
if len(b_char) > 10:
ddata['nlong_sequences'] += 1
# print(b_char)
def inline_diff(a, b, ddata_specialchars={'lost': {}, 'add': {}, 'wrong': {}, 'UL': {}}):
matcher = difflib.SequenceMatcher(None, a, b)
ddata = {'res_text': ''}
# ddata = ddata | {key: 0 for key in ['nsingle', 'nmultiple']}
ddata = ddata | {key: 0 for key in ['UL_single', 'UL_multiple']}
ddata = ddata | {
key: 0 for key in
['nlosts', 'nlosts_single', 'nlosts_multiple', 'nlosts_chars', 'nlosts_specialchars', 'nlosts_space']}
ddata = ddata | {
key: 0 for key in
['nadds', 'nadds_single', 'nadds_multiple', 'nadds_chars', 'nadds_specialchars', 'nadds_space']}
ddata = ddata | {
key: 0 for key in
['nwrongs', 'nwrongs_single', 'nwrongs_multiple', 'nwrong_accent', 'nlost_accent', 'nwrong_vowels',
'nwrong_consonants', 'nwrong_specialchars']}
ddata['nlong_sequences'] = 0
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == 'replace': # wrong
process_replace_tag(matcher, i1, i2, j1, j2, ddata, ddata_specialchars)
if tag == 'delete': # OCR add char so the matcher "delete"
process_delete_tag(matcher, i1, i2, ddata, ddata_specialchars)
if tag == 'equal':
ddata['res_text'] += matcher.a[i1:i2]
if tag == 'insert': # OCR lost char so the matcher "insert"
process_insert_tag(matcher, j1, j2, ddata, ddata_specialchars)
ddata["ned"] = ddata['nwrongs'] + ddata['nadds'] + ddata['nlosts']
return ddata
def process_single_file(file_name, ddata_specialchars):
# read predict file
with open(os.path.join(PREDICT_PATH, file_name), 'r') as f:
predict = f.readlines()[0].strip()
# predict = ''.join(predict)
# predict = predict.replace(' ', '')
# predict = predict.replace('\n', '')
# print(predict)
# read groundtruth file
with open(os.path.join(GROUNDTRUTH_PATH, file_name), 'r') as f:
gt = f.readlines()[0].strip()
# gt = ''.join(gt)
# gt = gt.replace('\n', '')
# get statiscal data of difference between predict and ground truth
ddata = inline_diff(predict, gt, ddata_specialchars)
global TOTAL_WORD
TOTAL_WORD = TOTAL_WORD + len(gt.split())
# write to save_path
res_text = ddata.pop('res_text', None)
save_file = os.path.join(RES_PATH, file_name)
with open(save_file, 'w') as f:
f.write(res_text)
# generate csv file
ddata = {'file_name': save_file} | ddata
return ddata
def main(overwrite=False):
for accent_file in [WRONG_ACCENT_FILE, LOST_ACCENT_FILE]:
if os.path.exists(accent_file):
os.remove(accent_file)
lddata = []
ddata_specialchars = {'lost': {}, 'add': {}, 'wrong': {}, 'UL': {}}
for file_ in glob.glob(f'{PREDICT_PATH}/*.txt'):
file_name = file_.split('/')[-1]
ddata = process_single_file(file_name, ddata_specialchars)
lddata.append(ddata)
if overwrite:
df = pd.DataFrame(lddata)
df.to_csv(f'{SAVE_PATH}/wiki_diff.csv', sep='\t')
df_ = pd.DataFrame(ddata_specialchars)
df_.to_csv(f'{SAVE_PATH}/wiki_diff_specialchars.csv', sep='\t')
logger.info(TOTAL_WORD)
if __name__ == '__main__':
main(overwrite=True)