import requests
import time
import argparse
import multiprocessing
import tqdm
import random
import traceback


# Command-line options for the load test. Only the credentials are mandatory;
# every other option falls back to its default.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--host",
    default="https://sbt.idp.sdsrv.ai",
)
parser.add_argument(
    "-u", "--username",
    required=True,
    help="Username to connect to server",
)
parser.add_argument(
    "-p", "--password",
    required=True,
    help="Password to connect to server",
)
parser.add_argument(
    "--num_requests",
    type=int,
    default=100,
    help="Number of requests",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=3,
    help="Number of workers",
)
parser.add_argument(
    "--checking_interval",
    type=float,
    default=0.5,
    help="Interval result checking time",
)
args = parser.parse_args()

# Seconds a request may spend polling for its result before it is reported
# as a timeout.
PROCESSING_TIMEOUT = 60


# =================================================================
# GET THE TOKEN
# Log in once; the token is sent in the Authorization header of later requests.
try:
    response = requests.post(f'{args.host}/api/ctel/login/', json={
        'username': args.username,
        'password': args.password
    }, timeout=100)
except requests.exceptions.RequestException as e:
    # Connection errors and timeouts: there is no response body to show.
    print("Failed to login")
    print(e)
    raise SystemExit(1)
try:
    token = response.json()['token']
except (ValueError, KeyError, TypeError):
    # ValueError: the body is not JSON.
    # KeyError / TypeError: the JSON carries no 'token' entry.
    print("Failed to login")
    print(response.content)
    # Without a token nothing below can work, so stop with a non-zero exit
    # code instead of failing later on an undefined name.
    raise SystemExit(1)
# After the login, store the token in the memory (RAM) or DB.
# Re-login to issue a new token after 6 days.
# =================================================================

def _failure_result(status):
    """Build the result record for a request that did not succeed."""
    return {
        "success": False,
        "status": status,
        "upload_time": 0,
        "process_time": 0,
        "num_files": 0,
    }


def _request_json(send, url, token, action, **kwargs):
    """Send one authorized request and decode the JSON object in its body.

    Args:
        send: The ``requests`` function to call (``requests.post`` or
            ``requests.get``).
        url: Target URL.
        token: Value sent in the ``Authorization`` header.
        action: Short description used in the printed error messages.
        **kwargs: Extra keyword arguments forwarded to ``send``.

    Returns:
        ``(payload, None)`` when a JSON object was received, otherwise
        ``(None, status)`` where ``status`` is ``"timeout"`` or
        ``"unknown error"``.
    """
    try:
        response = send(url, headers={
            'Authorization': token,
        }, timeout=100, **kwargs)
    except requests.exceptions.Timeout:
        print(f"Timeout occurred while {action}")
        return None, "timeout"
    except Exception as e:
        print(e)
        traceback.print_exc()
        print(f"Unknown exception occurred while {action}")
        return None, "unknown error"
    try:
        payload = response.json()
    except ValueError:
        # A non-JSON body (e.g. an HTML error page) must not crash the worker:
        # an exception here would propagate through the pool and abort the run.
        print(f"Invalid JSON received while {action}")
        print(response.content)
        return None, "unknown error"
    if not isinstance(payload, dict):
        print(f"Unexpected response received while {action}")
        print(payload)
        return None, "unknown error"
    return payload, None


def process_file(data):
    """Upload one batch of files, then poll until its result is available.

    Args:
        data: A ``(files, token)`` pair. ``files`` is a list of multipart
            entries in the format accepted by ``requests``; ``token`` is the
            value sent in the ``Authorization`` header.

    Returns:
        A dict with the keys ``success``, ``status``, ``upload_time``,
        ``process_time`` and ``num_files``. On success ``status`` is 200 and
        the timings are in seconds. On failure the timings and ``num_files``
        are 0 and ``status`` is ``"timeout"``, ``"unknown error"`` or the
        non-200 status found in the result payload.
    """
    files, token = data
    num_files = len(files)
    # Build a new list so the caller's list is left untouched.
    form_fields = list(files) + [('processType', (None, 12))]

    # =================================================================
    # UPLOAD THE FILE
    upload_started = time.time()
    payload, error = _request_json(
        requests.post, f'{args.host}/api/ctel/images/process/', token,
        "uploading", files=form_fields,
    )
    if error is not None:
        return _failure_result(error)
    if "request_id" not in payload:
        print("Missing request_id")
        print(payload)
        return _failure_result("unknown error")
    request_id = payload['request_id']
    upload_time = time.time() - upload_started
    # =================================================================

    # =================================================================
    # CHECK THE RESULT
    poll_started = time.time()
    while True:
        payload, error = _request_json(
            requests.get, f'{args.host}/api/ctel/result/{request_id}/', token,
            "requerying result",
        )
        if error is not None:
            return _failure_result(error)
        result = payload.get("data")
        if result:
            print(result)  # Got the response
            # Report the same status that is being checked (the one inside
            # "data"); a missing status is treated as 200.
            result_status = result.get("status", 200)
            if result_status != 200:
                return _failure_result(result_status)
            break
        # No result yet: give up after PROCESSING_TIMEOUT seconds of polling.
        if time.time() - poll_started > PROCESSING_TIMEOUT:
            print("Timeout!")
            return _failure_result("timeout")
        time.sleep(args.checking_interval)
    process_time = time.time() - poll_started
    # =================================================================
    return {
        "success": True,
        "status": 200,
        "upload_time": upload_time,
        "process_time": process_time,
        "num_files": num_files,
    }

def _read_sample(path):
    """Return the raw bytes of the file at ``path``, closing it afterwards."""
    with open(path, "rb") as sample:
        return sample.read()


# Sample uploads, loaded once at start-up. Each entry is
# (form field name, (filename, file content)).
invoice_files = [
    ('invoice_file', ('invoice.jpg', _read_sample("test_samples/sbt/invoice.jpg"))),
]
# Each IMEI entry is read from the same path it is named after, so the first
# entry carries imei1.jpg rather than a second copy of the invoice image.
imei_files = [
    ('imei_files', (path, _read_sample(path)))
    for path in (
        "test_samples/sbt/imei1.jpg",
        "test_samples/sbt/imei2.jpg",
        "test_samples/sbt/imei3.jpg",
        "test_samples/sbt/imei4.jpeg",
        "test_samples/sbt/imei5.jpg",
    )
]
def get_imei_files():
    """Return a random-length prefix of ``imei_files``.

    Returns:
        The first ``k`` entries of ``imei_files``, with ``k`` drawn uniformly
        from 1 to ``len(imei_files)``; an empty list when ``imei_files`` is
        empty.
    """
    if not imei_files:
        return []
    # randint includes both endpoints, so the upper bound is the list length
    # itself; one more would make the full list twice as likely as any other
    # size.
    num_files = random.randint(1, len(imei_files))
    return imei_files[:num_files]
def get_files():
    """Return a new list: the invoice entries followed by a random IMEI selection."""
    request_files = [*invoice_files, *get_imei_files()]
    return request_files
def gen_input(num_input):
    """Lazily yield ``num_input`` work items, each a ``(files, token)`` pair."""
    for _request_index in range(num_input):
        request_files = get_files()
        yield request_files, token
def build_report(results, num_requests, num_workers, checking_interval):
    """Return the test report as a list of text lines.

    Args:
        results: Result dicts with the keys ``success``, ``status``,
            ``upload_time``, ``process_time`` and ``num_files``.
        num_requests: Number of requests that were issued.
        num_workers: Number of concurrent workers.
        checking_interval: Seconds between two result queries.

    Returns:
        The report lines in print order. The timing figures are left out when
        no request succeeded, since there is nothing to average.
    """
    separator = "--------------------------------------"
    lines = [
        "## TEST REPORT #################################",
        "Number of requests: {}".format(num_requests),
        "Number of concurrent requests: {}".format(num_workers),
        "Number of files: 1 invoice, 1-5 imei files (random)",
        "Query time interval for result: {:.3f}s ".format(checking_interval),
        separator,
        "SUCCESS RATE",
    ]

    counter = {}
    for result in results:
        counter[result["status"]] = counter.get(result["status"], 0) + 1
    total_requests = sum(counter.values())
    lines.append("Success rate: {}".format(counter.get(200, 0) / total_requests if total_requests > 0 else -1))
    lines.append("Statuses: {}".format(counter))
    lines.append(separator)

    succeeded = [x for x in results if x["success"]]
    uploading_time = [x["upload_time"] for x in succeeded]
    processing_time = [x["process_time"] for x in succeeded]

    lines.append("TIME BY REQUEST")
    if not succeeded:
        lines.append("No valid uploading time")
        lines.append("Check the results!")
    else:
        lines.append("Uploading time (Avg / Min / Max): {:.3f}s {:.3f}s {:.3f}s".format(sum(uploading_time) / len(uploading_time), min(uploading_time), max(uploading_time)))
        lines.append("Processing time (Avg / Min / Max): {:.3f}s {:.3f}s {:.3f}s".format(sum(processing_time) / len(processing_time), min(processing_time), max(processing_time)))
    lines.append(separator)

    lines.append("TIME BY IMAGE")
    num_images = sum(x["num_files"] for x in succeeded)
    lines.append("Total images: {}".format(num_images))
    if num_images > 0:
        lines.append("Uploading time: {:.3f}s".format(sum(uploading_time) / num_images))
        lines.append("Processing time: {:.3f}s".format(sum(processing_time) / num_images))
    lines.append(separator)
    return lines


if __name__ == "__main__":
    # Guarded so that worker processes which re-import this module (spawn
    # start method) do not try to create a pool of their own.
    with multiprocessing.Pool(processes=args.num_workers) as pool:
        # All results are collected before the pool is torn down on exit.
        results = list(tqdm.tqdm(
            pool.imap_unordered(process_file, gen_input(num_input=args.num_requests)),
            total=args.num_requests,
        ))

    for line in build_report(results, args.num_requests, args.num_workers, args.checking_interval):
        print(line)