from omegaconf import OmegaConf
import torch
import sys
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/')  # TODO: remove this hard-coded project path

from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight


class KVUPredictor:
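    """Inference wrapper around the Key-Value Understanding model.

    Mode semantics (inferred from the dispatch in ``predict``):
        0 - ``combined_predict`` on a single flat sample,
        1 - ``cat_predict`` on a document split into windows,
        2 - ``combined_predict`` run independently on each window,
        3 - ``doc_predict`` over up to ``max_window_count`` sliding windows.
    """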
    def __init__(self, configs, class_names, dummy_idx, mode=0):
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']

        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode
        # Fall back to CPU when no GPU is available; set before _load_model uses it.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print('[INFO] Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        print('[INFO] Loaded model')

        if mode == 3:
            # Document-level mode: one dummy index per window.
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words
        
    def _load_model(self, cfg_path, ckpt_path):
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        print('[INFO] Checkpoint:', ckpt_path)
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to(self.device)
        net.eval()
        return net, cfg, backbone_type
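
    # _load_model reads only a handful of config fields. A minimal YAML sketch
    # (field names come from the attribute accesses above; the values here are
    # hypothetical placeholders):
    #
    #   model:
    #     backbone: layoutxlm
    #   train:
    #     max_window_count: 3
    #     window_size: 512
    #     slice_interval: 256
    #     max_num_words: 512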
    
    def predict(self, input_sample):
        """Dispatch on ``self.mode``; each branch returns parallel lists of
        (bbox, lwords, pr_class_words, pr_relations), one entry per window
        in mode 2 and a single entry otherwise."""
        if self.mode == 0:
            # Single flat sample.
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 1:
            # Windows concatenated back into one document.
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 2:
            # Each window predicted independently.
            if len(input_sample['windows'][0]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations

        elif self.mode == 3:
            # Document-level prediction over sliding windows.
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        else:
            raise ValueError(f"Unsupported mode: {self.mode}")
            
    def doc_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        # Batch each window's tensors and move them to the target device.
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {
                k: v.unsqueeze(0).to(self.device)
                for k, v in window.items()
                if k not in ('words', 'n_empty_windows')
            }

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = input_sample['documents']
        
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        # Argmax over the class dimension of each head, then drop the batch dim.
        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)  
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        # Merge relations predicted from headers and from keys (set union).
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        lwords = input_sample['words']
        # Batch the tensors and move them to the target device.
        input_sample = {
            k: v.unsqueeze(0).to(self.device)
            for k, v in input_sample.items()
            if k not in ('words', 'img_path')
        }

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)  
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations
    
    def cat_predict(self, input_sample):
        lwords = input_sample['documents']['words']

        inputs = []
        for window in input_sample['windows']:
            inputs.append({k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        # The exclusion must be a tuple: ('embedding_tokens') without a comma is just a string.
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]
        
        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['documents']['are_box_first_tokens']
        attention_mask = input_sample['documents']['attention_mask_layoutxlm'] 
        bbox = input_sample['documents']['bbox']

        # Document-level dummy index: the full document sequence length.
        dummy_idx = input_sample['documents']['bbox'].shape[0]

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_el_label = ground_truth['el_labels'].squeeze(0)
        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        attention_mask = ground_truth['attention_mask'].squeeze(0)

        bbox = ground_truth['bbox'].squeeze(0)
        gt_first_words = parse_initial_words(
            gt_itc_label, box_first_token_mask, self.class_names
        )
        gt_class_words = parse_subsequent_words(
            gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
        )
        
        gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
        gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
        gt_relations = gt_relations_from_header | gt_relations_from_key
        
        return bbox, lwords, gt_class_words, gt_relations
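

# A minimal usage sketch. The paths, class list, and dummy index below are
# hypothetical placeholders; only the KVUPredictor API defined above is real.
if __name__ == '__main__':
    configs = {
        'cfg': 'path/to/experiment.yaml',   # hypothetical config path
        'ckpt': 'path/to/checkpoint.pth',   # hypothetical checkpoint path
    }
    class_names = ['header', 'key', 'value', 'other']  # hypothetical class set

    predictor = KVUPredictor(configs, class_names, dummy_idx=512, mode=0)

    # `input_sample` is expected to be a dict of tensors plus a 'words' list,
    # produced by the project's preprocessing pipeline (not shown here):
    # bbox, lwords, pr_class_words, pr_relations = predictor.predict(input_sample)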