# https://stackoverflow.com/questions/774316/python-difflib-highlighting-differences-inline
import difflib
import unidecode
import os
import glob
import pandas as pd

# Latin alphabet split into vowels ('y' counted as a vowel) and consonants,
# lower case first, then upper case.
VOWELS = 'aeouiy'
VOWELS += VOWELS.upper()
CONSONANTS = 'bcdfghjklmnpqrstvxwz'
CONSONANTS += CONSONANTS.upper()

# Active dataset: Cinnamon handwritten lines (the cinamon data also goes
# through a different preprocessing). For the wiki dataset use instead:
#   PREDICT_PATH = 'ocr/result'
#   GROUNDTRUTH_PATH = '/mnt/hdd2T/AICR/Datasets/wiki/ground_truth'
#   SAVE_PATH = 'wiki_diff'
PREDICT_PATH = 'ocr/result/cinamon'
GROUNDTRUTH_PATH = '/mnt/hdd2T/AICR/Datasets/Backup/1.Hand_writing/Lines/cinnamon_data'
SAVE_PATH = 'wiki_diff/cinamon'

# Output locations derived from SAVE_PATH.
RES_PATH = f'{SAVE_PATH}/result/'  # one annotated diff per input file
WRONG_ACCENT_FILE = f'{SAVE_PATH}/wrong_accent.txt'
LOST_ACCENT_FILE = f'{SAVE_PATH}/lost_accent.txt'

# Running total of ground-truth words, accumulated by process_single_file.
TOTAL_WORD = 0


def write_accent_error(path, err):
    """Append one accent-error description as its own line to ``path``.

    ``path`` is meant to be WRONG_ACCENT_FILE or LOST_ACCENT_FILE. The file is
    written as UTF-8 explicitly because ``err`` holds Vietnamese letters with
    diacritics (e.g. 'Ơ -> Ớ'), which the platform default encoding may not be
    able to represent.
    """
    with open(path, 'a', encoding='utf-8') as f:
        f.write(err + '\n')


def update_ddata_specialchars(ddata_specialchars, correction_key, char_key):
    """Increment the occurrence counter of ``char_key`` in one error bucket.

    ``ddata_specialchars`` maps a correction category (callers in this file use
    'lost', 'add', 'wrong' and 'UL') to a ``{char_key: count}`` dict.
    ``correction_key`` selects the bucket, which must already exist; a
    ``char_key`` seen for the first time starts at 1.
    """
    counts = ddata_specialchars[correction_key]
    counts[char_key] = counts.get(char_key, 0) + 1


def process_replace_tag(matcher, i1, i2, j1, j2, ddata, ddata_specialchars):
    """Record a difflib 'replace' opcode: prediction ``a[i1:i2]`` stands where
    the ground truth has ``b[j1:j2]``.

    Appends a ``{pred -> gt}`` marker to ``ddata['res_text']`` and updates the
    counters in ``ddata``:

    * ``nwrongs`` grows by the number of ground-truth characters replaced.
    * One char vs one char: ``UL_single`` if only the letter case differs;
      otherwise ``nwrongs_single`` plus exactly one of ``nlost_accent`` /
      ``nwrong_accent`` (same base vowel), ``nwrong_vowels``,
      ``nwrong_consonants`` or ``nwrong_specialchars``.
    * Longer spans: ``UL_multiple`` or ``nwrongs_multiple``; spans longer than
      10 characters on either side also bump ``nlong_sequences``.

    Case-only and special-character substitutions are also tallied per
    ``(pred, gt)`` pair in ``ddata_specialchars``.
    """
    a_char = matcher.a[i1:i2]
    b_char = matcher.b[j1:j2]
    ddata['res_text'] += ' ### {' + a_char + ' -> ' + b_char + '} ### '
    # NOTE(review): only the ground-truth length is counted, so when the
    # predicted span is longer its extra characters never reach 'ned'
    # (a Levenshtein count would use max(len(a_char), len(b_char))) - confirm.
    ddata['nwrongs'] += len(b_char)
    if len(a_char) == 1 and len(b_char) == 1:  # single char case
        if a_char.lower() == b_char.lower():  # wrong upper/lower case
            ddata['UL_single'] += 1
            update_ddata_specialchars(ddata_specialchars, 'UL', (a_char, b_char))
        else:
            ddata['nwrongs_single'] += 1
            # Strip diacritics so e.g. 'Ớ' and 'O' share the base letter 'o'.
            a_ori = unidecode.unidecode(a_char).lower()
            b_ori = unidecode.unidecode(b_char).lower()
            # The transliteration may be empty or several characters long, and
            # `s in VOWELS` is a substring test ('' and 'ae' would both pass),
            # so only a single transliterated character may count as a letter.
            both_single = len(a_ori) == 1 and len(b_ori) == 1
            if both_single and a_ori in VOWELS and b_ori in VOWELS:
                if a_ori == b_ori:
                    err = a_char + ' -> ' + b_char
                    # NOTE(review): "lost" is decided by the ground-truth side
                    # (b) carrying no diacritic, i.e. the prediction had an
                    # accent the ground truth lacks - confirm the direction.
                    if b_ori == b_char.lower():  # e.g. Ơ -> O
                        ddata['nlost_accent'] += 1
                        # write_accent_error(LOST_ACCENT_FILE, err)
                    else:  # e.g. Ơ -> Ớ
                        ddata['nwrong_accent'] += 1
                        # write_accent_error(WRONG_ACCENT_FILE, err)
                else:  # e.g. Ă -> Â
                    ddata['nwrong_vowels'] += 1
            elif both_single and a_ori in CONSONANTS and b_ori in CONSONANTS:
                ddata['nwrong_consonants'] += 1
            else:
                ddata['nwrong_specialchars'] += 1
                update_ddata_specialchars(ddata_specialchars, 'wrong', (a_char, b_char))
    else:
        if a_char.lower() == b_char.lower():
            ddata['UL_multiple'] += 1
            update_ddata_specialchars(ddata_specialchars, 'UL', (a_char, b_char))
        else:
            ddata['nwrongs_multiple'] += 1
            if len(a_char) > 10 or len(b_char) > 10:
                ddata['nlong_sequences'] += 1


def process_delete_tag(matcher, i1, i2, ddata, ddata_specialchars):
    """Record a difflib 'delete' opcode: prediction chars ``a[i1:i2]`` that are
    absent from the ground truth, i.e. characters the OCR added.

    Appends a ``{- chars}`` marker to ``ddata['res_text']``, grows ``nadds`` by
    the span length and classifies the span:

    * one character: ``nadds_single`` plus ``nadds_chars`` (any letter,
      including Vietnamese letters with diacritics such as 'ế' or 'đ'),
      ``nadds_space``, or ``nadds_specialchars`` (also tallied in
      ``ddata_specialchars['add']``);
    * longer spans: ``nadds_multiple``, plus ``nlong_sequences`` when longer
      than 10 characters.
    """
    a_char = matcher.a[i1:i2]
    ddata['res_text'] += ' ### {- ' + a_char + '} ### '
    ddata['nadds'] += len(a_char)
    if len(a_char) == 1:
        ddata['nadds_single'] += 1
        # isalpha() accepts every ASCII letter (as the CONSONANTS + VOWELS
        # alphabet does) and also accented letters, which would otherwise be
        # misfiled as special characters.
        if a_char.isalpha():
            ddata['nadds_chars'] += 1
        elif a_char == ' ':
            ddata['nadds_space'] += 1
        else:
            ddata['nadds_specialchars'] += 1
            update_ddata_specialchars(ddata_specialchars, 'add', a_char)
    else:
        ddata['nadds_multiple'] += 1
        if len(a_char) > 10:
            ddata['nlong_sequences'] += 1


def process_insert_tag(matcher, j1, j2, ddata, ddata_specialchars):
    """Record a difflib 'insert' opcode: ground-truth chars ``b[j1:j2]`` that
    are missing from the prediction, i.e. characters the OCR lost.

    Appends a ``{+ chars}`` marker to ``ddata['res_text']``, grows ``nlosts``
    by the span length and classifies the span:

    * one character: ``nlosts_single`` plus ``nlosts_chars`` (any letter,
      including Vietnamese letters with diacritics such as 'ế' or 'đ'),
      ``nlosts_space``, or ``nlosts_specialchars`` (also tallied in
      ``ddata_specialchars['lost']``);
    * longer spans: ``nlosts_multiple``, plus ``nlong_sequences`` when longer
      than 10 characters.
    """
    b_char = matcher.b[j1:j2]
    ddata['nlosts'] += len(b_char)
    ddata['res_text'] += ' ### {+ ' + b_char + '} ### '
    if len(b_char) == 1:
        ddata['nlosts_single'] += 1
        # isalpha() accepts every ASCII letter (as the CONSONANTS + VOWELS
        # alphabet does) and also accented letters, which would otherwise be
        # misfiled as special characters.
        if b_char.isalpha():
            ddata['nlosts_chars'] += 1
        elif b_char == ' ':
            ddata['nlosts_space'] += 1
        else:
            ddata['nlosts_specialchars'] += 1
            update_ddata_specialchars(ddata_specialchars, 'lost', b_char)
    else:
        ddata['nlosts_multiple'] += 1
        if len(b_char) > 10:
            ddata['nlong_sequences'] += 1


def inline_diff(a, b, ddata_specialchars=None):
    """Diff prediction ``a`` against ground truth ``b`` character by character.

    Parameters
    ----------
    a : str
        OCR prediction; the annotated text is rendered from it.
    b : str
        Ground-truth text.
    ddata_specialchars : dict, optional
        Accumulator with the buckets 'lost', 'add', 'wrong' and 'UL', updated
        in place so counts can be aggregated over many files. When omitted, a
        fresh accumulator is created for this call only.

    Returns
    -------
    dict
        'res_text' (``a`` with ``### {...} ###`` markers around every
        difference), the per-category error counters, and 'ned', the sum of
        'nwrongs', 'nadds' and 'nlosts'.
    """
    if ddata_specialchars is None:
        ddata_specialchars = {'lost': {}, 'add': {}, 'wrong': {}, 'UL': {}}
    matcher = difflib.SequenceMatcher(None, a, b)
    # Key order is kept stable: it determines the column order of the CSV
    # built from these dicts in main().
    counter_keys = (
        'UL_single', 'UL_multiple',
        'nlosts', 'nlosts_single', 'nlosts_multiple', 'nlosts_chars',
        'nlosts_specialchars', 'nlosts_space',
        'nadds', 'nadds_single', 'nadds_multiple', 'nadds_chars',
        'nadds_specialchars', 'nadds_space',
        'nwrongs', 'nwrongs_single', 'nwrongs_multiple', 'nwrong_accent',
        'nlost_accent', 'nwrong_vowels', 'nwrong_consonants',
        'nwrong_specialchars',
        'nlong_sequences',
    )
    ddata = {'res_text': ''}
    ddata.update(dict.fromkeys(counter_keys, 0))
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            ddata['res_text'] += matcher.a[i1:i2]
        elif tag == 'replace':  # OCR recognised the wrong characters
            process_replace_tag(matcher, i1, i2, j1, j2, ddata, ddata_specialchars)
        elif tag == 'delete':  # OCR added chars, so the matcher "deletes" them
            process_delete_tag(matcher, i1, i2, ddata, ddata_specialchars)
        elif tag == 'insert':  # OCR lost chars, so the matcher "inserts" them
            process_insert_tag(matcher, j1, j2, ddata, ddata_specialchars)
    ddata['ned'] = ddata['nwrongs'] + ddata['nadds'] + ddata['nlosts']
    return ddata


def _read_first_line(path):
    """Return the first line of the UTF-8 text file at ``path``, stripped.

    An empty file yields '' instead of raising IndexError.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return f.readline().strip()


def process_single_file(file_name, ddata_specialchars):
    """Diff one prediction against its ground truth and save the annotated text.

    ``file_name`` is looked up in both PREDICT_PATH and GROUNDTRUTH_PATH; only
    the first line of each file is compared, stripped of surrounding
    whitespace. ``ddata_specialchars`` is updated in place by ``inline_diff``,
    the ground-truth word count is added to the global TOTAL_WORD, and the
    marked-up prediction is written (UTF-8) to ``RES_PATH/file_name``, whose
    directory must already exist.

    Returns the per-file statistics dict without 'res_text', with
    'file_name' (the path of the written diff) as its first key.
    """
    global TOTAL_WORD

    predict = _read_first_line(os.path.join(PREDICT_PATH, file_name))
    gt = _read_first_line(os.path.join(GROUNDTRUTH_PATH, file_name))

    # Statistics of the differences between prediction and ground truth.
    ddata = inline_diff(predict, gt, ddata_specialchars)
    TOTAL_WORD += len(gt.split())

    res_text = ddata.pop('res_text')
    save_file = os.path.join(RES_PATH, file_name)
    with open(save_file, 'w', encoding='utf-8') as f:
        f.write(res_text)

    # 'file_name' goes first so it becomes the first CSV column.
    return {'file_name': save_file} | ddata


def main(overwrite=False):
    """Diff every ``*.txt`` prediction in PREDICT_PATH and report statistics.

    Removes stale accent-error logs, writes one annotated diff per prediction
    file into RES_PATH, and prints the total ground-truth word count. When
    ``overwrite`` is true, the per-file statistics and the aggregated
    special-character counts are also saved as tab-separated CSVs under
    SAVE_PATH.
    """
    for accent_file in (WRONG_ACCENT_FILE, LOST_ACCENT_FILE):
        if os.path.exists(accent_file):
            os.remove(accent_file)
    # process_single_file writes into RES_PATH (inside SAVE_PATH); create the
    # output tree up front instead of failing on the first file.
    os.makedirs(RES_PATH, exist_ok=True)
    lddata = []
    ddata_specialchars = {'lost': {}, 'add': {}, 'wrong': {}, 'UL': {}}
    # glob order is arbitrary; sort it so the CSV rows are reproducible.
    for path in sorted(glob.glob(f'{PREDICT_PATH}/*.txt')):
        lddata.append(process_single_file(os.path.basename(path), ddata_specialchars))
    if overwrite:
        df = pd.DataFrame(lddata)
        df.to_csv(f'{SAVE_PATH}/wiki_diff.csv', sep='\t')
        df_ = pd.DataFrame(ddata_specialchars)
        df_.to_csv(f'{SAVE_PATH}/wiki_diff_specialchars.csv', sep='\t')
    print(TOTAL_WORD)


if __name__ == '__main__':
    # Script entry point: analyse all predictions and save the CSV reports.
    main(True)