from PIL import Image
import cv2
import numpy as np
import torch
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
    LayoutLMv2ForTokenClassification,
)

import os
import ssl

from common.utils.word_formation import *
from common.utils.global_variables import *
from common.utils.process_label import *

# Allow HTTPS downloads to proceed even when certificate verification fails.
ssl._create_default_https_context = ssl._create_unverified_context
os.environ["CURL_CA_BUNDLE"] = ""

from PIL import ImageFile

# Tolerate truncated image files instead of raising on load.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# config
IGNORE_KIE_LABEL = "others"
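# NOTE: this order doubles as the model's label-id mapping (predictions are
# decoded via KIE_LABELS[pred] below), so it must match the order used at training time.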
KIE_LABELS = [
    "Number",
    "Name",
    "Birthday",
    "Home Town",
    "Address",
    "Sex",
    "Nationality",
    "Expiry Date",
    "Nation",
    "Religion",
    "Date Range",
    "Issued By",
    IGNORE_KIE_LABEL,
    "Rank"
]
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# MAX_SEQ_LENGTH = 512  # TODO Fix this hard code

# tokenizer = LayoutXLMTokenizer.from_pretrained(
#     "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
# )

# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)

model = LayoutLMv2ForTokenClassification.from_pretrained(
    "Kie_AHung/model/driver_license",  # TODO: fix this hard-coded checkpoint path
    num_labels=len(KIE_LABELS),
    local_files_only=True,
).to(DEVICE)
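# Minimal sketch of how the `processor` argument expected by
# infer_driving_license() below could be built; it mirrors the commented-out
# block above, and the tokenizer path / sequence length are assumptions taken
# from that block rather than a confirmed configuration.
def build_processor(
    tokenizer_path="Kie_AHung/model/pretrained/layoutxlm-base/tokenizer",
    max_seq_length=512,
):
    tokenizer = LayoutXLMTokenizer.from_pretrained(
        tokenizer_path, model_max_length=max_seq_length
    )
    # apply_ocr=False because word texts and boxes come from our own OCR stage.
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    return LayoutXLMProcessor(feature_extractor, tokenizer)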


def load_ocr_labels(list_lines):
    """Flatten OCR lines into parallel lists of word texts, boxes, and labels."""
    words, boxes, labels = [], [], []
    for line in list_lines:
        for word_group in line.list_word_groups:
            for word in word_group.list_words:
                xmin, ymin, xmax, ymax = (
                    word.boundingbox[0],
                    word.boundingbox[1],
                    word.boundingbox[2],
                    word.boundingbox[3],
                )
                text = word.text
                label = "seller_name_value"  # TODO ??? fix this
                x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax)
                if text != " ":
                    words.append(text)
                    boxes.append([x1, y1, x2, y2])
                    labels.append(label)
    return words, boxes, labels
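

# Illustrative usage of load_ocr_labels(); `ocr_lines` is a placeholder for the
# line objects produced by the OCR stage, not a name defined in this module:
#
#   words, boxes, labels = load_ocr_labels(ocr_lines)
#   assert len(words) == len(boxes) == len(labels)  # lists are index-aligned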


def _normalize_box(box, width, height):
    """Scale absolute pixel coordinates into the 0-1000 range used by LayoutLM-family models."""
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]
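

# Worked example for _normalize_box(): on a 1280x960 image, the pixel box
# [128, 96, 640, 480] maps to [100, 100, 500, 500] in the model's 0-1000 space:
#
#   _normalize_box([128, 96, 640, 480], 1280, 960)  # -> [100, 100, 500, 500]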


def infer_driving_license(image_crop, list_lines, max_n_words, processor):
    """Run LayoutXLM token classification over OCR words from a cropped driving-license image.

    Words are fed to the model in chunks of at most `max_n_words`; the per-word
    predictions and boxes from all chunks are concatenated before returning.
    """
    # The crop arrives as an OpenCV BGR array; convert it to an RGB PIL image.
    image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
    batch_preds, batch_true_boxes = [], []
    list_words = []
    for i in range(0, len(batch_words), max_n_words):
        words = batch_words[i : i + max_n_words]
        boxes = batch_boxes[i : i + max_n_words]
        boxes_norm = [
            _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
        ]

        # Preprocess. Dummy word labels make the processor return a `labels`
        # tensor, which is used below to locate subword/special-token positions.
        dummy_word_labels = [0] * len(words)
        encoding = processor(
            image,
            text=words,
            boxes=boxes_norm,
            word_labels=dummy_word_labels,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # Run model (inference only, so gradients are not needed)
        for k, v in encoding.items():
            encoding[k] = v.to(DEVICE)
        with torch.no_grad():
            outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1).squeeze().tolist()

        # Postprocess: -100 labels mark padding, special tokens, and subword
        # continuations, so keep exactly one prediction per original word.
        is_subword = (encoding["labels"] == -100).detach().cpu().numpy()[0]
        true_predictions = [
            pred for idx, pred in enumerate(predictions) if not is_subword[idx]
        ]
        # TODO: check the assumption that LayoutLM does not reorder the boxes.
        true_boxes = boxes

        for w_idx, word in enumerate(words):
            bndbox = [int(coord) for coord in true_boxes[w_idx]]
            list_words.append(
                Word(
                    text=word,
                    bndbox=bndbox,
                    kie_label=KIE_LABELS[true_predictions[w_idx]],
                )
            )

        batch_preds.extend(true_predictions)
        batch_true_boxes.extend(true_boxes)

    batch_preds = np.array(batch_preds)
    batch_true_boxes = np.array(batch_true_boxes)
    return batch_words, batch_preds, batch_true_boxes, list_words
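

# Convenience sketch (not part of the original pipeline): collect the per-word
# predictions returned by infer_driving_license() into one string per KIE
# label. It only assumes each Word exposes the `text` and `kie_label`
# attributes used above.
def group_words_by_label(list_words):
    fields = {}
    for word in list_words:
        if word.kie_label == IGNORE_KIE_LABEL:
            continue  # drop words predicted as "others"
        fields.setdefault(word.kie_label, []).append(word.text)
    return {label: " ".join(texts) for label, texts in fields.items()}


# Example call chain; `cropped_bgr` and `ocr_lines` are placeholders, and
# max_n_words=150 is an arbitrary example value:
#
#   processor = build_processor()
#   words, preds, boxes, kie_words = infer_driving_license(
#       cropped_bgr, ocr_lines, max_n_words=150, processor=processor
#   )
#   fields = group_words_by_label(kie_words)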