import cv2
from common.ocr import ocr_predict
from common.crop_location import crop_location
from Kie_AHung.prediction import infer_driving_license
from Kie_AHung_ID.prediction import infer_id_card
from common.post_processing_datetime import DatetimeCorrector
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor
    )

# Maximum token sequence length for the LayoutXLM tokenizer; passed as
# model_max_length when the tokenizer is loaded below.
MAX_SEQ_LENGTH = 512  # TODO Fix this hard code
# Tokenizer loaded once, at import time, from a local checkpoint directory.
# The path is relative, so it is resolved against the process working
# directory. NOTE(review): confirm the service is always started from the
# repository root, otherwise this load fails at import.
tokenizer = LayoutXLMTokenizer.from_pretrained(
    "./Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
)

# The feature extractor's built-in OCR is disabled (apply_ocr=False); text
# lines come from this project's own OCR step (ocr_predict) and are handed to
# the KIE inference functions separately.
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# Combined image + text processor, shared by the driving-license and ID-card
# KIE inference functions.
processor = LayoutXLMProcessor(feature_extractor, tokenizer)

# Passed as max_n_words to the KIE inference functions. Presumably caps the
# number of OCR words fed to the model. TODO confirm in Kie_AHung.prediction.
max_n_words = 100

def _select_kie(infer_name):
    """Return ``(infer_fn, create_kie_dict)`` for the requested document type.

    Both post-processing modules export a function with the same name,
    ``create_kie_dict``. Importing them locally, per branch, avoids that name
    clash at module level.

    Args:
        infer_name (str): ``"driving_license"`` or ``"id_card"``.

    Returns:
        tuple: the KIE inference function and the matching post-processing
        function that turns its word list into a label -> result dict.

    Raises:
        ValueError: if ``infer_name`` is not a supported document type.
    """
    if infer_name == "driving_license":
        from common.post_processing_driver import create_kie_dict
        return infer_driving_license, create_kie_dict
    if infer_name == "id_card":
        from common.post_processing_id import create_kie_dict
        return infer_id_card, create_kie_dict
    raise ValueError(
        "Unsupported infer_name {!r}; expected 'driving_license' or 'id_card'".format(infer_name)
    )


def _build_fields(result_dict):
    """Convert a KIE result dict into the list of field dicts returned by the API.

    Args:
        result_dict (dict): mapping of label -> entry. Each entry provides
            ``"text"`` and ``"bbox"``.

    Returns:
        list[dict]: one ``{"label", "value", "box", "confidence"}`` dict per
        label with non-empty text, in the iteration order of ``result_dict``.
    """
    fields = []
    for kie_label, entry in result_dict.items():
        text = entry["text"]
        if text == "":
            # Labels the model found nothing for are omitted from the output.
            continue
        if kie_label == "Date Range":
            text = DatetimeCorrector.correct(text)
        fields.append({
            "label": kie_label,
            # Every "✪" in the extracted text is rendered as a space.
            "value": text.replace("✪", " "),
            "box": entry["bbox"],
            "confidence": 0.99  # TODO: add confidence
        })
    return fields


def predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name):
    """Crop, OCR and run key-information extraction (KIE) on each page of a request.

    Args:
        rq_id: request id; used only to build the crop-image path.
        sub_id: subscription id; used only to build the crop-image path.
        folder_name: request folder name; used only to build the crop-image path.
        list_url (list[dict]): pages to process. Each dict must provide
            ``"file_url"`` (passed to ``crop_location``), ``"page_number"``
            and ``"request_file_id"``.
        user_id: user id; used only to build the crop-image path.
        infer_name (str): ``"driving_license"`` or ``"id_card"``. Selects the
            KIE model and post-processing, and is reported as ``model.name``.

    Returns:
        dict: ``{"model": {...}, "pages": [...]}``. ``pages`` has one entry per
        input page with keys ``page_index``, ``path_image_croped``,
        ``request_file_id`` and ``fields``. When ``crop_location`` returns
        None for a page, ``path_image_croped`` and ``fields`` are None.
        Otherwise ``fields`` is a list of ``{"label", "value", "box",
        "confidence"}`` dicts.

    Raises:
        ValueError: if a page is successfully cropped and ``infer_name`` is
            not a supported document type.
    """
    results = {
        "model": {
            "name": infer_name,
            "confidence": 1.0,
            # NOTE(review): reported for every infer_name, including ID
            # documents; confirm this type string is intended.
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }

    # The crop path does not include the page number, so every page of a
    # request overwrites the same file and all page entries report the same
    # path. NOTE(review): add page_number to the filename if per-page crops
    # must be kept.
    crop_suffix = "users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(
        user_id, sub_id, folder_name, rq_id
    )
    # The image is written under "/users/..." but reported under
    # "/app/media/users/...". NOTE(review): presumably the same volume mounted
    # at different paths in different containers; confirm.
    write_path = "/" + crop_suffix
    path_image_croped = "/app/media/" + crop_suffix

    compile_outputs = []
    for page in list_url:
        image_location = crop_location(page['file_url'])
        if image_location is None:
            # No crop was produced for this page: report it without an image or fields.
            compile_outputs.append({
                'page_index': page['page_number'],
                'path_image_croped': None,
                'request_file_id': page['request_file_id'],
                'fields': None
            })
            continue

        # Resolve the model before any file write or OCR, so an unsupported
        # infer_name fails fast with a clear error.
        infer_fn, create_kie_dict = _select_kie(infer_name)

        # NOTE(review): cv2.imwrite returns False on failure; the result is
        # ignored, so the reported path may point at a missing file.
        cv2.imwrite(write_path, image_location)
        list_line = ocr_predict(image_location)
        _, _, _, list_words = infer_fn(image_location, list_line, max_n_words, processor)
        result_dict = create_kie_dict(list_words)

        compile_outputs.append({
            'page_index': page['page_number'],
            'path_image_croped': path_image_croped,
            'request_file_id': page['request_file_id'],
            'fields': _build_fields(result_dict)
        })

    results['pages'] = compile_outputs
    return results