import os
import yaml
from pathlib import Path
from PIL import Image
from io import BytesIO
import numpy as np
import torch
from sdsvtd import StandaloneYOLOXRunner

from common.utils.word_formation import Word, words_to_lines

def read_imagefile(file) -> Image.Image:
    """Open an encoded image held in memory as a PIL image.

    Args:
        file: Bytes-like object holding the encoded image data; it is
            wrapped in a ``BytesIO`` so PIL can read it like a file.

    Returns:
        The ``PIL.Image.Image`` produced by ``Image.open``.
    """
    return Image.open(BytesIO(file))

def sort_bboxes(lbboxes) -> list:
    """Reorder bounding boxes by the line grouping of ``words_to_lines``.

    Each box is wrapped in a ``Word``, the words are grouped into lines by
    ``words_to_lines``, and the boxes are read back line by line, word group
    by word group, word by word.

    Args:
        lbboxes: Iterable of bounding boxes, each passed as ``bndbox`` to
            ``Word``.

    Returns:
        A flat list of the ``boundingbox`` of every word, in the traversal
        order described above. (The previous annotation promised a tuple of
        two lists, but a single list has always been returned.)
    """
    words = [Word(bndbox=bbox) for bbox in lbboxes]
    # The second value returned by words_to_lines is not needed here.
    lines, _ = words_to_lines(words)
    return [
        word.boundingbox
        for line in lines
        for word_group in line.list_word_groups
        for word in word_group.list_words
    ]

class Predictor:
    """Detector wrapper that returns clipped, ordered boxes per image.

    The backend is selected by the ``mode`` entry of the settings file:

    * ``'torch'`` -- a ``StandaloneYOLOXRunner`` stored in ``self.runner``.
    * ``'trt'``   -- an mmdeploy backend model stored in ``self.trt_infer``.

    Settings keys read: ``model_config``, ``mode``, ``device`` and, for
    ``'trt'`` only, ``mmdeploy_path``, ``deploy_config`` and ``checkpoint``.
    """

    def __init__(self, setting_file='./setting.yml'):
        """Load the settings and build the configured inference backend.

        Args:
            setting_file: Path to the YAML settings file.

        Raises:
            ValueError: If ``mode`` is neither ``'trt'`` nor ``'torch'``.
        """
        with open(setting_file) as f:
            # safe_load is used instead of load so the YAML file cannot
            # construct arbitrary Python objects.
            self.setting = yaml.safe_load(f)

        # Config paths in the settings are resolved against the parent of
        # the directory that contains this file.
        base_path = Path(__file__).parent
        model_config_path = os.path.join(base_path, '../', self.setting['model_config'])
        self.mode = self.setting['mode']
        device = self.setting['device']

        if self.mode == 'trt':
            # mmdeploy is imported lazily from the configured location, so
            # it is only required when the TensorRT backend is requested.
            import sys
            sys.path.append(self.setting['mmdeploy_path'])
            from mmdeploy.utils import get_input_shape, load_config
            from mmdeploy.apis.utils import build_task_processor

            deploy_config_path = os.path.join(base_path, '../', self.setting['deploy_config'])

            class TensorRTInfer:
                """Callable wrapper around an mmdeploy backend model."""

                def __init__(self, deploy_config_path, model_config_path, checkpoint_path, device='cuda:0'):
                    deploy_cfg, model_cfg = load_config(deploy_config_path, model_config_path)
                    self.task_processor = build_task_processor(model_cfg, deploy_cfg, device)
                    self.model = self.task_processor.init_backend_model([checkpoint_path])
                    self.input_shape = get_input_shape(deploy_cfg)

                def __call__(self, images):
                    model_input, _ = self.task_processor.create_input(images, self.input_shape)
                    with torch.no_grad():
                        results = self.model(return_loss=False, rescale=True, **model_input)
                    return results

            checkpoint_path = self.setting['checkpoint']
            self.trt_infer = TensorRTInfer(deploy_config_path, model_config_path, checkpoint_path, device=device)
        elif self.mode == 'torch':
            self.runner = StandaloneYOLOXRunner(version=self.setting['model_config'], device=device)
        else:
            raise ValueError('No such inference mode')

    def __call__(self, images):
        """Detect boxes in every image and return them clipped and ordered.

        Args:
            images: Sequence of images; each must expose ``shape`` whose
                first two entries are height and width.

        Returns:
            One list per image, holding integer boxes in the order produced
            by ``sort_bboxes``.

        Raises:
            ValueError: If ``self.mode`` is not a supported inference mode.
        """
        if self.mode == 'torch':
            result = [self.runner(image) for image in images]
        elif self.mode == 'trt':
            # Must match the mode name accepted in __init__; the previous
            # check against 'tensorrt' never matched and left `result`
            # unbound.
            result = self.trt_infer(images)
        else:
            raise ValueError('No such inference mode')

        sorted_result = []
        for res, image in zip(result, images):
            h, w = image.shape[:2]
            res = res[0][:, :4]  # leave out confidence score

            # Clip inside the image range: columns 0 and 2 are bounded by
            # the width, columns 1 and 3 by the height (presumably
            # x1, y1, x2, y2 -- confirm against the detector output).
            res[:, 0::2] = np.clip(res[:, 0::2], a_min=0, a_max=w)
            res[:, 1::2] = np.clip(res[:, 1::2], a_min=0, a_max=h)

            res = res.astype(int).tolist()
            sorted_result.append(sort_bboxes(res))
        return sorted_result