import requests
import time
import argparse
import multiprocessing
import tqdm
import random
import traceback


# Command-line options for the load test.
parser = argparse.ArgumentParser(description="Load-test the process_sync upload endpoint.")
parser.add_argument("--host", dest="host", default="https://sbt.idp.sdsrv.ai", required=False,
                    help="Base URL of the server")
parser.add_argument("-u", "--username", help="Username to connect to server", required=True)
parser.add_argument("-p", "--password", help="Password to connect to server", required=True)
parser.add_argument("--num_requests", type=int, help="Number of requests", required=False, default=100)
parser.add_argument("--num_workers", type=int, help="Number of workers", required=False, default=3)
# NOTE(review): only echoed in the report; no polling loop in this script reads it.
parser.add_argument("--checking_interval", type=float, help="Interval result checking time", required=False, default=0.5)
args = parser.parse_args()

PROCESSING_TIMEOUT = 60


# =================================================================
# GET THE TOKEN
# After the login, store the token in the memory (RAM) or DB.
# Re-login to issue a new token after 6 days.
response = requests.post(
    f'{args.host}/api/ctel/login/',
    json={
        'username': args.username,
        'password': args.password,
    },
    timeout=60,  # never hang forever on an unreachable server
)
try:
    token = response.json()['token']
except (ValueError, KeyError, TypeError):
    # ValueError: body is not JSON; KeyError/TypeError: JSON without a 'token' field.
    # Without a token every upload would fail (and `token` would be undefined for
    # gen_input), so stop instead of running a doomed load test.
    print("Failed to login")
    print(response.content)
    raise SystemExit(1)
# =================================================================

def _failed_result(status):
    """Build the result record for a request that yielded no usable timing."""
    return {
        "success": False,
        "status": status,
        "upload_time": 0,
        "process_time": 0,
        "num_files": 0,
    }


def process_file(data):
    """Upload one set of files to ``process_sync`` and time the round trip.

    Args:
        data: ``(files, token)`` tuple as yielded by ``gen_input``. ``files`` is a
            list of ``requests`` multipart entries ``(field, (filename, content))``.

    Returns:
        A dict with keys ``success``, ``status``, ``upload_time``, ``process_time``
        and ``num_files``. ``status`` is the HTTP status code when the server
        answered with a JSON body, otherwise one of ``"timeout"``,
        ``"unknown error"`` or ``"invalid response"``.
    """
    files, token = data
    num_files = len(files)
    # Build a new list rather than appending to the caller's.
    form = [*files, ('processType', (None, 12))]
    # =================================================================
    # UPLOAD THE FILE
    start_time = time.time()
    try:
        response = requests.post(f'{args.host}/api/ctel/images/process_sync/', headers={
            'Authorization': token,
        }, files=form, timeout=300)
    except requests.exceptions.Timeout:
        print("Timeout occurred while uploading")
        return _failed_result("timeout")
    except Exception as e:
        # Worker boundary: record the failure instead of aborting the whole run.
        print(e)
        traceback.print_exc()
        print("Unknown exception occurred while uploading")
        return _failed_result("unknown error")
    upload_time = time.time() - start_time
    # =================================================================

    try:
        body = response.json()
    except ValueError:
        # Body is not JSON; nothing to report beyond the raw content.
        print(response.content)
        return _failed_result("invalid response")
    if isinstance(body, dict):
        body.pop("files", None)  # keep the printed summary short
    print(body)
    if response.status_code != 200:
        return _failed_result(response.status_code)
    # The endpoint is presumably synchronous (`process_sync`), so the single
    # round trip is reported as both upload and processing time.
    return {
        "success": True,
        "status": response.status_code,
        "upload_time": upload_time,
        "process_time": upload_time,
        "num_files": num_files,
    }

def _read_bytes(path):
    """Return the full contents of ``path`` as bytes, closing the file afterwards."""
    with open(path, "rb") as fh:
        return fh.read()


# Sample attachments, read once at start-up; every request re-sends these bytes.
# Alternative invoice sample (PDF):
#     ('invoice_file', ('invoice.pdf', _read_bytes("test_samples/20220303025923NHNE_20220222_Starhub_Order_Confirmation_by_Email.pdf"))),
invoice_files = [
    ('invoice_file', ('invoice.jpg', _read_bytes("test_samples/sbt/invoice.jpg"))),
]
imei_files = [
    ('imei_files', ("test_samples/sbt/imei1.jpg", _read_bytes("test_samples/sbt/imei1.jpg"))),
    ('imei_files', ("test_samples/sbt/imei2.jpg", _read_bytes("test_samples/sbt/imei2.jpg"))),
]
def get_imei_files(num_files=2):
    """Return the first ``num_files`` IMEI attachments for one request.

    Args:
        num_files: how many entries of ``imei_files`` to send. Slicing clamps,
            so a value larger than ``len(imei_files)`` simply returns them all.
    """
    print("Num imeis", num_files)
    return imei_files[:num_files]
def get_files():
    """Assemble a fresh multipart file list: the invoice entries, then the IMEI images."""
    payload = list(invoice_files)
    payload.extend(get_imei_files())
    return payload
# Seconds between consecutive request submissions: 60 / (12 / 3) == 15.0,
# i.e. 4 requests per minute.
SUBMIT_INTERVAL = 60 / (12 / 3)


def gen_input(num_input, interval=SUBMIT_INTERVAL):
    """Yield ``num_input`` ``(files, token)`` work items, at most one per ``interval`` seconds.

    The generator is consumed lazily by the pool, so sleeping here paces how fast
    requests are handed to the workers. The first item is yielded immediately;
    each later item waits until ``interval`` seconds have passed since the
    previous one.

    Args:
        num_input: number of work items to produce.
        interval: minimum spacing in seconds between items (default 15.0).
    """
    last = None
    for _ in range(num_input):
        if last is not None:
            # monotonic() is immune to wall-clock adjustments during the run.
            wait = interval - (time.monotonic() - last)
            if wait > 0:
                print(f"[INFO]: sleep for {wait}")
                time.sleep(wait)
        last = time.monotonic()
        yield (get_files(), token)
# Fan the requests out over a process pool. gen_input() sleeps between items,
# so the pool size only bounds how many requests can be in flight at once.
# NOTE(review): the pool has num_workers * 5 processes, while the report prints
# num_workers as the concurrency — confirm which is intended.
# The context manager shuts the pool down once every result has been collected.
with multiprocessing.Pool(processes=args.num_workers * 5) as pool:
    results = list(
        tqdm.tqdm(
            pool.imap_unordered(process_file, gen_input(num_input=args.num_requests)),
            total=args.num_requests,
        )
    )

def _fmt_stats(values):
    """Format avg / min / max of a non-empty sequence of seconds."""
    return "{:.3f}s {:.3f}s {:.3f}s".format(sum(values) / len(values), min(values), max(values))


successful = [x for x in results if x["success"]]
uploading_time = [x["upload_time"] for x in successful]
processing_time = [x["process_time"] for x in successful]

print("## TEST REPORT #################################")
print("Number of requests: {}".format(args.num_requests))
print("Number of concurrent requests: {}".format(args.num_workers))
# Report the file counts actually sent (invoice + imei) rather than a hard-coded claim.
files_per_request = sorted({x["num_files"] for x in successful})
print("Files per request (invoice + imei): {}".format(files_per_request or "n/a"))
print("Query time interval for result: {:.3f}s ".format(args.checking_interval))
print("--------------------------------------")
print("SUCCESS RATE")
counter = {}
for result in results:
    counter[result["status"]] = counter.get(result["status"], 0) + 1
total_requests = sum(counter.values())
print("Success rate: {}".format(counter.get(200, 0) / total_requests if total_requests > 0 else -1))
print("Statuses:", counter)
print("--------------------------------------")
print("TIME BY REQUEST")
if uploading_time:
    print("Uploading time (Avg / Min / Max): " + _fmt_stats(uploading_time))
    print("Processing time (Avg / Min / Max): " + _fmt_stats(processing_time))
else:
    # Avoid ZeroDivisionError / min() of an empty list when nothing succeeded.
    print("No valid uploading time")
    print("Check the results!")
print("--------------------------------------")
print("TIME BY IMAGE")
num_images = sum(x["num_files"] for x in successful)
print("Total images:", num_images)
if num_images:
    print("Uploading time: {:.3f}s".format(sum(uploading_time) / num_images))
    print("Processing time: {:.3f}s".format(sum(processing_time) / num_images))
else:
    print("No images were processed successfully")
print("--------------------------------------")