from sdsvkie import Predictor
import cv2
import numpy as np
import urllib
from common import serve_model
from common import ocr

# Invoice key-information-extraction predictor, built once when this module
# is imported so that every predict() call reuses the same instance.
# The device is hard-coded to the first CUDA device ("cuda:0").
# NOTE(review): `weights` lives under ./models/models/... while `cfg` lives
# under ./models/... -- confirm the doubled "models" segment is intentional.
# NOTE(review): the keyword is spelled "proccessor"; presumably this matches
# Predictor's own parameter name -- verify against sdsvkie before renaming.
model = Predictor(
    cfg="./models/kie_invoice/config.yaml",
    device="cuda:0",
    weights="./models/models/kie_invoice/last",
    proccessor=serve_model.processor,
    ocr_engine=ocr.engine,
)

def _load_image(image_url):
    """Download ``image_url`` and decode it into an OpenCV image array.

    Args:
        image_url (str): URL of the image to fetch.

    Returns:
        numpy.ndarray: the decoded image, with channels and bit depth
        kept as stored in the file (``cv2.IMREAD_UNCHANGED``).

    Raises:
        urllib.error.URLError: if the URL cannot be fetched.
        ValueError: if the fetched bytes cannot be decoded as an image.
    """
    # `import urllib` alone does not load the `urllib.request` submodule, so
    # import it explicitly instead of relying on another module having done it.
    import urllib.request

    # NOTE(review): urlopen also accepts file:// and other schemes and has no
    # timeout here. If image_url can come from an untrusted client, consider
    # restricting it to http(s) and adding a timeout -- confirm with callers.
    with urllib.request.urlopen(image_url) as response:
        buffer = np.frombuffer(response.read(), dtype=np.uint8)

    # IMREAD_UNCHANGED == -1: keep alpha / grayscale / bit depth as stored.
    # NOTE(review): confirm the model accepts non-3-channel input.
    image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imdecode reports undecodable data by returning None, not raising.
        raise ValueError(f"could not decode image from {image_url!r}")
    return image


def _to_fields(results, page_numb):
    """Convert the predictor's ``end2end_results`` mapping into field dicts.

    Each entry of ``results`` is expected to carry ``"value"``, ``"box"``
    and ``"conf"`` keys; the mapping key becomes the field label.
    """
    return [
        {
            "label": label,
            # Falsy values (None, "", ...) are normalised to an empty string.
            "value": result["value"] or "",
            "box": result["box"],
            "confidence": result["conf"],
            "page": page_numb,
        }
        for label, result in results.items()
    ]


def predict(page_numb, image_url):
    """Run invoice key-information extraction on a single page image.

    Args:
        page_numb: page identifier copied verbatim into every returned
            field's ``"page"`` entry.
        image_url (str): URL of the page image to download and analyse.

    Returns:
        dict: one field per key in the model's ``end2end_results``, e.g.::

            {
                "document_type": "invoice",
                "fields": [
                    {
                        "label": "Invoice Number",
                        "value": "INV-12345",
                        "box": [0, 0, 0, 0],
                        "confidence": 0.98,
                        "page": 1,
                    },
                    ...
                ],
            }

    Raises:
        urllib.error.URLError: if the image cannot be fetched.
        ValueError: if the fetched content cannot be decoded as an image.
    """
    image = _load_image(image_url)
    results = model(image)["end2end_results"]
    return {
        "document_type": "invoice",
        "fields": _to_fields(results, page_numb),
    }