import os
import glob
import cv2
import argparse
from tqdm import tqdm
from datetime import datetime
# from omegaconf import OmegaConf
import sys
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/common/AnyKey_Value') # TODO: ????
from predictor import KVUPredictor
from preprocess import KVUProcess, DocumentKVUProcess
from utils.utils import create_dir, visualize, get_colormap, export_kvu_for_VAT_invoice, export_kvu_outputs


def get_args():
    """Parse the command-line options of this KVU inference script.

    Returns:
        argparse.Namespace with the attributes ``img_dir``, ``save_dir``,
        ``exp_dir``, ``export_img``, ``mode`` and ``dir_level``.
    """
    parser = argparse.ArgumentParser(description='Main file')

    # One (flag, default, type, help) tuple per option, registered in order.
    option_specs = (
        ('--img_dir', '/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', str,
         'Input image directory'),
        ('--save_dir', '/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', str,
         'Save directory'),
        ('--exp_dir', '/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/experiments/key_value_understanding-20230608-171900', str,
         'Checkpoint and config of model'),
        ('--export_img', 0, int,
         'Save visualize on image'),
        ('--mode', 3, int,
         "0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'"),
        ('--dir_level', 0, int,
         'Number of subfolders contains image'),
    )
    for flag, default_value, value_type, help_text in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type, help=help_text)

    return parser.parse_args()


def load_engine(exp_dir: str, class_names: list, mode: int) -> "tuple[KVUPredictor, DocumentKVUProcess]":
    """Build the KVU predictor and its matching document preprocessor.

    Args:
        exp_dir: Experiment directory. It must contain at least one ``*.yaml``
            config file; the checkpoint is read from
            ``<exp_dir>/checkpoints/best_model.pth``.
        class_names: Entity class labels, forwarded to both the predictor and
            the processor.
        mode: Prediction mode forwarded to both objects (per the ``--mode``
            help text: 0 normal, 1 full_tokens, 2 sliding, 3 document).

    Returns:
        A ``(predictor, processor)`` tuple.

    Raises:
        FileNotFoundError: if ``exp_dir`` contains no ``*.yaml`` config.
    """
    # Escape exp_dir so glob metacharacters in the path ('[', '*', '?') are
    # matched literally. Sort so the chosen config is deterministic when
    # several yaml files are present.
    cfg_candidates = sorted(glob.glob(os.path.join(glob.escape(exp_dir), '*.yaml')))
    if not cfg_candidates:
        raise FileNotFoundError(f'No *.yaml config found in experiment dir: {exp_dir}')

    configs = {
        'cfg': cfg_candidates[0],
        'ckpt': os.path.join(exp_dir, 'checkpoints', 'best_model.pth'),
    }
    # Forwarded to KVUPredictor unchanged; its meaning is defined there (TODO confirm).
    dummy_idx = 512
    predictor = KVUPredictor(configs, class_names, dummy_idx, mode)

    # The processor is configured from the predictor so that tokenization and
    # windowing match the loaded network.
    processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor,
                                   predictor.backbone_type, class_names,
                                   predictor.max_window_count, predictor.slice_interval, predictor.window_size,
                                   run_ocr=1, mode=mode)
    return predictor, processor


def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor):
    """Run KVU prediction on one image file and export the result as JSON.

    The JSON is written to ``<save_dir>/<image stem>.json`` by
    ``export_kvu_for_VAT_invoice``.

    Args:
        img_path: Path of the image to process.
        save_dir: Directory that receives the JSON output.
        predictor: Predictor returned by ``load_engine``.
        processor: Preprocessor returned by ``load_engine``; called as
            ``processor(img_path, ocr_path='')``.

    Returns:
        Whatever ``export_kvu_for_VAT_invoice`` returns.

    Raises:
        NotImplementedError: if the predictor returns more than one bbox group
            (multi-window output), which this exporter cannot handle.
    """
    fname = os.path.basename(img_path)
    # splitext only strips the real suffix. str.replace on the extension would
    # also hit earlier occurrences, and corrupts extension-less names
    # ("abc".replace("", ".json") inserts ".json" between every character).
    json_path = os.path.join(save_dir, os.path.splitext(fname)[0] + ".json")

    inputs = processor(img_path, ocr_path='')
    bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)

    if len(bbox) == 0:
        # No per-group split: the predictor outputs are exported as-is.
        return export_kvu_for_VAT_invoice(json_path, lwords, pr_class_words,
                                          pr_relations, predictor.class_names)

    if len(bbox) == 1:
        # Kept for parity with existing behavior: creates <save_dir>/kvu_results,
        # although nothing is written there by this function.
        create_dir(os.path.join(save_dir, 'kvu_results'))
        return export_kvu_for_VAT_invoice(json_path, lwords[0], pr_class_words[0],
                                          pr_relations[0], predictor.class_names)

    # Several groups (sliding-window output) previously fell through without
    # assigning a result and crashed with UnboundLocalError; fail explicitly.
    raise NotImplementedError(
        f'Exporting {len(bbox)} prediction windows is not supported '
        f'(image: {img_path!r}); only 0 or 1 windows can be exported.'
    )
        

def Predictor_KVU(img, save_dir: str, predictor: KVUPredictor, processor,
                  tmp_dir: str = "/home/thucpd/thucpd/PV2-2023/tmp_image"):
    """Run KVU prediction on an in-memory image.

    The processor consumes file paths, so the image is first written to a
    uniquely named ``.jpg`` in ``tmp_dir`` and then passed to ``predict_image``.
    The temporary image is left on disk, as before.

    Args:
        img: Image accepted by ``cv2.imwrite`` (presumably a decoded numpy
            array; confirm with callers).
        save_dir: Directory where ``predict_image`` writes the JSON output.
        predictor: Predictor returned by ``load_engine``.
        processor: Preprocessor returned by ``load_engine``.
        tmp_dir: Directory for the temporary image file; created if missing.

    Returns:
        Whatever ``predict_image`` returns.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the image could not be written.
    """
    import tempfile

    os.makedirs(tmp_dir, exist_ok=True)
    curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    # A bare second-resolution timestamp let two calls in the same second
    # overwrite each other's image. mkstemp reserves a unique name atomically;
    # the timestamp prefix keeps files sortable by time.
    fd, image_path = tempfile.mkstemp(prefix=f'{curr_datetime}_', suffix='.jpg', dir=tmp_dir)
    os.close(fd)  # cv2 reopens the file by path; release our handle first
    if not cv2.imwrite(image_path, img):
        os.remove(image_path)
        raise OSError(f'cv2.imwrite failed to write temporary image: {image_path}')

    return predict_image(image_path, save_dir, predictor, processor)


if __name__ == "__main__":
    # Demo entry point: build the engine from the CLI options, then run
    # prediction on a single sample image.
    # NOTE(review): image_path / save_dir below are hard-coded and take
    # precedence over --img_dir / --save_dir; args.save_dir is only created.
    args = get_args()

    class_names = ['others', 'title', 'key', 'value', 'header']
    # Reference mapping of mode names to --mode codes (not used below).
    predict_mode = dict(normal=0, full_tokens=1, sliding=2, document=3)

    predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
    create_dir(args.save_dir)

    image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
    save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
    vat_outputs = predict_image(image_path, save_dir, predictor, processor)

    print('[INFO] Done')