from torch import nn
import torch
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
    LayoutLMv2ForTokenClassification,
)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, used as a drop-in replacement for the
    learned position-embedding table so that sequences longer than the
    pretrained maximum number of positions can be embedded."""

    def __init__(self, num_hiddens, max_len=10000):
        super().__init__()
        # The encoding table is built on the fly in `forward` for the actual
        # sequence length, so `max_len` is only kept for interface compatibility.
        self.num_hiddens = num_hiddens
    def forward(self, inputs):
        # `inputs` is the position_ids tensor of shape (1, seq_len) or
        # (batch, seq_len); return an encoding of shape (1, seq_len, num_hiddens)
        # that broadcasts against the word embeddings.
        max_len = inputs.shape[1]
        P = torch.zeros((1, max_len, self.num_hiddens))
        # Standard sinusoidal encoding: position / 10000^(2i / num_hiddens),
        # sine on even hidden dimensions, cosine on odd ones.
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000,
            torch.arange(0, self.num_hiddens, 2, dtype=torch.float32)
            / self.num_hiddens,
        )
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        return P.to(inputs.device)
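

# Illustrative sanity check (new helper, not used by the loaders below).
# LayoutLMv2's embedding layer calls `position_embeddings(position_ids)`, so the
# drop-in module above receives a (1, seq_len) position_ids tensor and must
# return a (1, seq_len, num_hiddens) table that broadcasts against the word
# embeddings. The hidden size of 768 below is an assumption matching the base
# LayoutLMv2/LayoutXLM configuration.
def _positional_encoding_shape_check(num_hiddens: int = 768, seq_len: int = 1024):
    pe = PositionalEncoding(num_hiddens=num_hiddens)
    position_ids = torch.arange(seq_len).expand((1, -1))
    encoding = pe(position_ids)
    assert encoding.shape == (1, seq_len, num_hiddens)
    return encoding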


def load_layoutlmv2_custom_model(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Load LayoutLMv2 for token classification and patch it to accept
    sequences longer than the pretrained position-embedding table."""
    model, processor = load_layoutlmv2(weight_dir, tokenizer_dir, max_seq_len, classes)
    # Fix for longer sequence lengths: swap the learned position-embedding
    # table for a sinusoidal encoding computed on the fly.
    model.layoutlmv2.embeddings.position_embeddings = PositionalEncoding(
        num_hiddens=768, max_len=max_seq_len
    )
    model.layoutlmv2.embeddings.max_position_embeddings = max_seq_len
    model.config.max_position_embeddings = max_seq_len
    # Re-register the position_ids buffer so it covers the new maximum length.
    model.layoutlmv2.embeddings.register_buffer(
        "position_ids", torch.arange(max_seq_len).expand((1, -1))
    )

    return model, processor


def load_layoutlmv2(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Build the LayoutXLM processor (tokenizer + feature extractor) and the
    LayoutLMv2 token-classification model with one label per class."""
    tokenizer = LayoutXLMTokenizer.from_pretrained(
        pretrained_model_name_or_path=tokenizer_dir, model_max_length=max_seq_len
    )
    # apply_ocr=False: words and bounding boxes are supplied by an external OCR
    # step rather than extracted by the feature extractor itself.
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    processor = LayoutXLMProcessor(feature_extractor, tokenizer)
    model = LayoutLMv2ForTokenClassification.from_pretrained(
        weight_dir, num_labels=len(classes)
    )
    return model, processor
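

# Minimal usage sketch, guarded so it only runs when the module is executed
# directly. The checkpoint/tokenizer paths and the label set are placeholders
# (assumptions), not values from the original code; a full forward pass would
# additionally require the LayoutLMv2 dependencies (e.g. detectron2).
if __name__ == "__main__":
    from PIL import Image

    labels = ["O", "B-HEADER", "I-HEADER"]  # hypothetical label set
    model, processor = load_layoutlmv2_custom_model(
        weight_dir="./checkpoints/layoutxlm-finetuned",  # placeholder path
        tokenizer_dir="./checkpoints/layoutxlm-base",    # placeholder path
        max_seq_len=1024,
        classes=labels,
    )
    print(model.config.max_position_embeddings)  # -> 1024 after the patch

    # Because apply_ocr=False, words and (0-1000 normalized) boxes come from an
    # external OCR step; here they are synthetic placeholders.
    image = Image.new("RGB", (1000, 1000), color="white")
    words = ["Invoice", "Total"]
    boxes = [[100, 100, 300, 140], [100, 200, 250, 240]]
    encoding = processor(
        image,
        words,
        boxes=boxes,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    print({k: tuple(v.shape) for k, v in encoding.items()})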