import nltk
import re
import string
import copy
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
# English vocabulary used by remove_english_words(). Load the locally installed
# corpus first so that importing this module does not contact the NLTK download
# server on every import; download only when the corpus is missing (NLTK signals
# a missing resource with LookupError).
try:
    words = set(nltk.corpus.words.words())
except LookupError:
    nltk.download('words')
    words = set(nltk.corpus.words.words())
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Apply the project's shared LOGGER_CONFIG to the logging module at import time.
# NOTE(review): this is a global side effect of importing this module — confirm
# that no other entry point expects to configure logging itself.
logging.config.dictConfig(LOGGER_CONFIG)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Accent-stripping tables consumed by remove_accents() (see preprocessing()):
# the character at position i of s1 (accented Latin / Vietnamese letters,
# including Đ/đ) is replaced by the character at position i of s0. The two
# strings must stay index-aligned, so edit them together.
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'

# def clean_text(text):
#     return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text)


def get_string(lwords: list):
    """Join tokens with single spaces, dropping repeated tokens.

    Tokens that consist of exactly one digit character are always kept, even
    when repeated. Every other token is kept only at its first occurrence.

    Args:
        lwords: list of string tokens.

    Returns:
        The kept tokens joined by ' '.
    """
    # A set makes the "already seen" test O(1) instead of scanning the output
    # list for every token (which was O(n^2) overall).
    seen = set()
    unique_list = []
    for item in lwords:
        if item.isdigit() and len(item) == 1:
            unique_list.append(item)
        elif item not in seen:
            seen.add(item)
            unique_list.append(item)
    return ' '.join(unique_list)

def remove_english_words(text):
    """Drop tokens that appear in the module-level English vocabulary ``words``.

    ``text`` is split with ``nltk.wordpunct_tokenize``; each token is
    lower-cased and kept only when it is not in ``words``. The surviving
    lower-cased tokens are joined with single spaces.
    """
    kept_tokens = []
    for token in nltk.wordpunct_tokenize(text):
        lowered = token.lower()
        if lowered not in words:
            kept_tokens.append(lowered)
    return ' '.join(kept_tokens)

def remove_punctuation(text):
    """Return ``text`` with every character of ``string.punctuation`` deleted."""
    # Mapping a code point to None makes str.translate delete it.
    deletions = {ord(mark): None for mark in string.punctuation}
    return text.translate(deletions)

def remove_accents(input_str, s0, s1):
    """Replace each character of ``input_str`` found in ``s1`` by its counterpart in ``s0``.

    ``input_str`` is first normalised to Unicode NFC. Text in decomposed form
    (a base letter followed by combining marks, e.g. "e" + U+0323 + U+0302) is
    thereby recomposed and matched against the precomposed characters in
    ``s1``; otherwise its accents would survive.

    Args:
        input_str: text to transliterate.
        s0: replacement characters, index-aligned with ``s1``.
        s1: characters to replace. If a character occurs more than once,
            its first position wins.

    Returns:
        The transliterated string; characters not in ``s1`` are kept unchanged.
    """
    import unicodedata

    # Position of each character in s1; first occurrence wins, matching s1.index().
    positions = {}
    for idx, ch in enumerate(s1):
        positions.setdefault(ch, idx)

    normalized = unicodedata.normalize('NFC', input_str)
    # One join instead of repeated string concatenation (quadratic in general).
    return ''.join(s0[positions[ch]] if ch in positions else ch for ch in normalized)

def remove_spaces(text):
    """Delete every ASCII space (' ') from ``text``; tabs and newlines are kept."""
    return ''.join(text.split(' '))

def preprocessing(text: str):
    """Normalise a label for dictionary matching.

    Steps: delete punctuation, map accented characters through the module
    tables ``s1`` -> ``s0`` with ``remove_accents``, delete spaces, lower-case.
    """
    without_punct = remove_punctuation(text)
    unaccented = remove_accents(without_punct, s0, s1)
    return remove_spaces(unaccented).lower()


def vat_standardize_outputs(vat_outputs: dict) -> dict:
    """Rename the keys of a VAT extraction result through ``DKVU2XML``.

    Every top-level key except ``"table"`` is translated with ``DKVU2XML``.
    ``"table"`` keeps its name; its value is iterated as a sequence of row
    dicts, and each row's keys are translated the same way.

    Raises:
        KeyError: if a key has no entry in ``DKVU2XML``.
    """
    standardized = {}
    for field, content in vat_outputs.items():
        if field == "table":
            standardized['table'] = [
                {DKVU2XML[column]: cell for column, cell in row.items()}
                for row in content
            ]
            continue
        standardized[DKVU2XML[field]] = content
    return standardized
    
                

def vat_standardizer(text: str, threshold: float, header: bool):
    """Map a raw VAT-invoice label to a key of ``vat_dictionary(header)``.

    The label is normalised with ``preprocessing`` and then matched, in order:

    1. If it is a non-empty substring of a day/month/year label variant
       ("ngay", "thang", "year", ...), return the invoice-date key with
       score 5. The returned processed text becomes "day", "month" or "year".
    2. For each key in dictionary order:
       - "kyboi" matches the seller-name key with score 8;
       - an exact alias match gives that key with score 10.
       When ``header`` is False, a few extra exact-only aliases are also
       accepted (e.g. "mst", "sono", "kyngay").
    3. Fuzzy match: every key except the date key is scored by its best
       ``longestCommonSubsequence(processed_text, alias) / len(alias)``.

    Args:
        text: raw label text.
        threshold: the best fuzzy score must be strictly greater than this.
        header: passed to ``vat_dictionary``; False also enables the extra
            exact-only aliases.

    Returns:
        ``(label, score, processed_text)``. ``label`` is the matched key, or the
        original ``text`` when the best fuzzy score does not exceed ``threshold``.
    """
    date_key = "Ngày, tháng, năm lập hóa đơn"
    seller_key = 'Tên người bán'
    dictionary = vat_dictionary(header)
    processed_text = preprocessing(text)

    # A (possibly partial) day/month/year label marks the date field. An empty
    # string is a substring of every candidate, so blank text must skip this
    # rule instead of being classified as a date.
    if processed_text:
        for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'),
                           ('thangmonth', 'thang', 'month'),
                           ('namyear', 'nam', 'year')]:
            if any(processed_text in txt for txt in candidates):
                return date_key, 5, candidates[-1]

    # Exact-match aliases. A shallow copy is enough: entries are only replaced
    # by new lists, never mutated in place.
    exact_aliases = dict(dictionary)
    if not header:
        extra_aliases = {
            'Số hóa đơn': ['sono', 'so'],
            'Mã số thuế người bán': ['mst'],
            'Tên người bán': ['kyboi'],
            'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'],
        }
        for k, extra in extra_aliases.items():
            exact_aliases[k] = dictionary[k] + extra

    for k in dictionary:
        # Equality, not `k in ('...')`: a parenthesised string is not a tuple,
        # so `in` would be a substring test on the key.
        if k == seller_key and processed_text == "kyboi":
            return k, 8, processed_text
        if any(processed_text == alias for alias in exact_aliases[k]):
            return k, 10, processed_text

    scores = {k: 0.0 for k in dictionary}
    for k, aliases in dictionary.items():
        if k == date_key:
            continue  # the date field is only matched by the rules above
        # Skip empty aliases (division by zero) and tolerate empty alias lists.
        scores[k] = max(
            (longestCommonSubsequence(processed_text, alias) / len(alias)
             for alias in aliases if alias),
            default=0.0,
        )

    key, score = max(scores.items(), key=lambda kv: kv[1])
    return (key if score > threshold else text), score, processed_text

def ap_standardizer(text: str, threshold: float, header: bool):
    """Map a raw label to a key of ``ap_dictionary(header)``.

    The label is normalised with ``preprocessing``. An exact alias match wins
    with score 10. Some extra aliases count only for exact matching:
    "sn"/"imel" when ``header`` is False, and "sku"/"sn"/"imei"/"qty" when it
    is True. Otherwise each key is scored by its best
    ``longestCommonSubsequence(processed_text, alias) / len(alias)``.

    Returns:
        ``(label, score, processed_text)``. ``label`` is the best-scoring key
        when its score is at least ``threshold``, else the original ``text``.
    """
    dictionary = ap_dictionary(header)
    processed_text = preprocessing(text)

    if header:
        extra_aliases = {'modelnumber': ['sku', 'sn', 'imei'], 'qty': ['qty']}
    else:
        extra_aliases = {'serial_number': ['sn'], 'imei_number': ['imel']}

    # Exact-match lookup = dictionary aliases plus the exact-only extras.
    exact_lookup = copy.deepcopy(dictionary)
    for field, aliases in extra_aliases.items():
        exact_lookup[field] = dictionary[field] + aliases

    for field in dictionary:
        if any(processed_text == alias for alias in exact_lookup[field]):
            return field, 10, processed_text

    scores = {
        field: max([longestCommonSubsequence(processed_text, alias) / len(alias) for alias in aliases])
        for field, aliases in dictionary.items()
    }
    best_field, best_score = max(scores.items(), key=lambda pair: pair[1])
    label = best_field if best_score >= threshold else text
    return label, best_score, processed_text


def convert_format_number(s: str) -> float:
    """Parse an OCR'd amount written in Vietnamese number format.

    Spaces are removed and the letters 'O'/'o' are read as the digit 0. A
    trailing ",00" or ".00" is dropped.

    - If both ',' and '.' occur, '.' is the thousands separator and ',' the
      decimal separator, and a float is returned ("1.234,05" -> 1234.05).
    - Otherwise both ',' and '.' are treated as thousands separators and an
      int is returned ("1.234.567" -> 1234567).

    Raises:
        ValueError: if the cleaned text is not a valid number.
    """
    s = s.replace(' ', '').replace('O', '0').replace('o', '0')
    if s.endswith((",00", ".00")):
        s = s[:-3]
    if ',' in s and '.' in s:
        integer_part, fraction = s.replace('.', '').split(',')[:2]
        whole = int(integer_part)  # validates the integer part
        # Keep the whole decimal part. The old `split('0')[0]` cut it at the
        # first zero ("05" -> "", "105" -> "1").
        fraction = fraction or '0'
        if not fraction.isdigit():
            raise ValueError("invalid decimal part in %r" % (s,))
        # Build the float from text: this keeps a leading '-' on the whole value
        # (adding a positive fraction to a negative int was wrong) and avoids
        # the rounding error of summing a separately divided fraction.
        sign = '-' if integer_part.startswith('-') else ''
        return float(sign + str(abs(whole)) + '.' + fraction)
    return int(s.replace(',', '').replace('.', ''))


def post_process_for_item(item: dict) -> dict:
    """Fill in one missing numeric field of a table row from the other two.

    The fields are 'Số lượng' (quantity), 'Đơn giá' (unit price) and
    'Doanh số mua chưa có thuế' (amount before tax). A field counts as missing
    when its value is None or the string '0'. Only when exactly one field is
    missing is it recomputed and stored as a string:

    - quantity = round(amount / price), skipped if price parses to 0;
    - price = amount / quantity, skipped if quantity parses to 0;
    - amount = quantity * price.

    The row is modified in place and also returned. If parsing or arithmetic
    fails, the error is logged and the row is left unchanged.

    Raises:
        KeyError: if one of the three fields is absent from ``item``.
    """
    qty_key, price_key, amount_key = 'Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế'
    missing = [key for key in (qty_key, price_key, amount_key) if item[key] in (None, '0')]
    if len(missing) != 1:
        return item
    try:
        if missing[0] == qty_key:
            price = convert_format_number(item[price_key])
            if price != 0:
                item[qty_key] = str(round(convert_format_number(item[amount_key]) / price))
        elif missing[0] == price_key:
            qty = convert_format_number(item[qty_key])
            if qty != 0:
                item[price_key] = str(convert_format_number(item[amount_key]) / qty)
        else:
            item[amount_key] = str(convert_format_number(item[qty_key]) * convert_format_number(item[price_key]))
    except Exception as e:
        # Deliberate best effort: an unparsable row is returned as is. The error
        # must be given a %s placeholder; without one, logging cannot format the
        # record and the message is lost.
        logger.info("Cannot post process this item with error: %s", e)
    return item


def longestCommonSubsequence(text1: str, text2: str) -> int:
    """Return the length of the longest common subsequence of ``text1`` and ``text2``.

    Standard O(len(text1) * len(text2)) dynamic programme that keeps only the
    previous row of the table.
    """
    prev_row = [0] * (len(text2) + 1)
    for ch1 in text1:
        cur_row = [0]
        for col, ch2 in enumerate(text2):
            if ch1 == ch2:
                cur_row.append(prev_row[col] + 1)
            else:
                cur_row.append(max(prev_row[col + 1], cur_row[col]))
        prev_row = cur_row
    return prev_row[-1]


def longest_common_subsequence_with_idx(X, Y):
    """
    Compute the LCS length of X and Y together with start/end positions in X.

    A dynamic-programming table L is filled so that L[i][j] is the LCS length
    of X[:i] and Y[:j]. One LCS is then traced back from L[len(X)][len(Y)];
    only positions are tracked, no characters are collected. X and Y only need
    len(), indexing and == on items (presumably strings here; confirm at call
    sites).

    Returns a tuple (lcs, start, end):
    - lcs: the length of the longest common subsequence of X and Y.
    - start: the 0-based index in X of the first character of the traced LCS
      (len(X) when lcs == 0).
    - end: max(start + lcs, p), where p is the length of the shortest prefix of
      X whose LCS with Y already has length lcs (p == 0 when lcs == 0).
      NOTE(review): presumably meant as an exclusive end index of the matched
      region in X. It does not always cover every character of the traced LCS,
      e.g. X="abcxb", Y="ab" gives (2, 0, 2) while the trace uses X[0] and X[4].
    """
    m, n = len(X), len(Y)
    L = [[0 for i in range(n + 1)] for j in range(m + 1)]

    # Build L bottom-up: L[i][j] is the LCS length of X[0..i-1] and Y[0..j-1].
    # right_idx records the row (prefix length of X) at which the running
    # maximum of L was last raised; at the end it is the shortest prefix of X
    # whose LCS with Y has the full length.
    right_idx = 0
    max_lcs = 0
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
                if L[i][j] > max_lcs:
                    max_lcs = L[i][j]
                    right_idx = i
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    # After both loops i == m and j == n, so this is L[m][n], the LCS length.
    lcs = L[i][j]
    # Trace one LCS back from the bottom-right corner. When the loop ends,
    # i is the index in X of the first character of the traced LCS
    # (or m if there is no common character).
    i = m
    j = n
    while i > 0 and j > 0:
        # Matching characters are part of the LCS: step diagonally.
        if X[i - 1] == Y[j - 1]:

            i -= 1
            j -= 1
        # Otherwise move towards the larger neighbour; on ties step back in Y
        # (j -= 1) rather than in X.
        elif L[i - 1][j] > L[i][j - 1]:
            i -= 1
        else:
            j -= 1
    return lcs, i, max(i + lcs, right_idx)