import requests
import time
import random
import argparse
import multiprocessing
import tqdm
import traceback
from requests_toolbelt import  MultipartEncoderMonitor
import requests


# Command-line configuration for the load test.
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="https://sbt.idp.sdsrv.ai")
parser.add_argument("-u", "--username", required=True, help="Username to connect to server")
parser.add_argument("-p", "--password", required=True, help="Password to connect to server")
parser.add_argument("--num_requests", type=int, default=100, help="Number of requests")
parser.add_argument("--num_workers", type=int, default=3, help="Number of workers")
# NOTE(review): not read anywhere else in this file as shown; confirm it is still needed.
parser.add_argument("--checking_interval", type=float, default=0.5, help="Interval result checking time")
args = parser.parse_args()

# NOTE(review): not referenced in this file as shown (the upload request uses
# its own 300 s timeout); presumably meant as a processing deadline in seconds.
PROCESSING_TIMEOUT = 60


# =================================================================
# GET THE TOKEN
# Authenticate once; the resulting token is handed to every worker request.
# A failed login aborts the run immediately: continuing would leave `token`
# unbound and the workers would fail later with a confusing NameError.
try:
    response = requests.post(f'{args.host}/api/ctel/login/', json={
        'username': args.username,
        'password': args.password
    }, timeout=30)  # never hang forever on an unreachable host
except requests.exceptions.RequestException as exc:
    print("Failed to login")
    print(exc)
    raise SystemExit(1)
try:
    token = response.json()['token']
except (ValueError, KeyError, TypeError):
    # ValueError: body is not JSON; KeyError/TypeError: JSON without a
    # top-level 'token' field.
    print("Failed to login")
    print(response.content)
    raise SystemExit(1)

# After the login, store the token in the memory (RAM) or DB
# Re-login to issue a new token after 6 days.
# =================================================================

def _failure(status):
    """Return the result record for a request that did not succeed.

    Timing and file-count fields are zeroed; the report only aggregates
    records whose ``success`` flag is set.
    """
    return {
        "success": False,
        "status": status,
        "upload_time": 0,
        "process_time": 0,
        "num_files": 0,
    }


def process_file(data):
    """Send one randomly composed request to ``process_sync`` and time it.

    Args:
        data: ``(unused, token)`` tuple as produced by ``gen_input``; only the
            auth token is used.

    Returns:
        A dict with ``success``, ``status`` (``200`` on success, otherwise a
        short error label such as ``"timeout"``, ``"unknown error"``,
        ``"invalid_response"`` or ``"missing_file"``), ``upload_time`` and
        ``process_time`` in seconds, and ``num_files`` (image files sent).
    """
    from contextlib import ExitStack

    _, token = data

    # Two independent random draws, so the effective mix is 20% invoice + 2
    # IMEI images, 48% (0.8 * 0.6) one IMEI image, 32% two IMEI images.
    if random.random() < 0.2:
        samples = [
            ('invoice_file', "invoice.jpg"),
            ('imei_files', "imei1.jpg"),
            ('imei_files', "imei2.jpg"),
        ]
    elif random.random() < 0.6:
        samples = [('imei_files', "imei1.jpg")]
    else:
        samples = [
            ('imei_files', "imei1.jpg"),
            ('imei_files', "imei2.jpg"),
        ]
    num_files = len(samples)

    with ExitStack() as stack:
        # Each opened sample is registered on the stack so it is closed as
        # soon as the request finishes, on success and on every error path.
        # The uploaded filename now matches the file actually sent.
        files = [
            (field, (name, stack.enter_context(open(f"test_samples/sbt/{name}", "rb")), 'application/octet-stream'))
            for field, name in samples
        ]
        files.append(('processType', '12'))
        # =================================================================
        # UPLOAD THE FILE
        start_time = time.time()
        end_of_upload_time = 0

        def my_callback(monitor):
            # Called as the encoder is read; stamp the moment the monitor
            # reports that the whole multipart body has been read.
            nonlocal end_of_upload_time
            if monitor.bytes_read == monitor.len:
                end_of_upload_time = time.time()

        m = MultipartEncoderMonitor.from_fields(
            fields=files,
            callback=my_callback
        )
        try:
            response = requests.post(f'{args.host}/api/ctel/images/process_sync/', headers={
                'Authorization': token,
                'Content-Type': m.content_type
            }, data=m, timeout=300)
        except requests.exceptions.Timeout:
            print("Timeout occurred while uploading")
            return _failure("timeout")
        except Exception as e:
            # Top-level boundary for a worker: report and keep the run going.
            print(e)
            traceback.print_exc()
            print("Unknown exception occurred while uploading")
            return _failure("unknown error")
        end_time = time.time()

    # If the monitor never reported a complete read, fall back to the response
    # time instead of computing 0 - start_time (a huge negative duration).
    upload_time = (end_of_upload_time or end_time) - start_time
    # =================================================================

    try:
        received = len(response.json()["files"])
    except (ValueError, KeyError, TypeError):
        # Non-JSON body or unexpected schema; this is not a timeout.
        print(response.content)
        return _failure("invalid_response")
    if received != num_files:
        return _failure("missing_file")
    return {
        "success": True,
        "status": 200,
        "upload_time": upload_time,
        "process_time": time.time() - start_time - upload_time,
        "num_files": num_files,
    }


def gen_input(num_input):
    """Lazily produce ``num_input`` pool work items, each ``(None, token)``.

    ``token`` is the module-level auth token, looked up as each item is made.
    """
    yield from ((None, token) for _ in range(num_input))
def _print_report(results):
    """Print success-rate and timing statistics for the collected results.

    Timing sections use only records with ``success`` set. When there are
    none, the timing sections are skipped instead of crashing on an empty
    ``min``/``max`` or a division by zero.
    """
    print("## TEST REPORT #################################")
    print("Number of requests: {}".format(args.num_requests))
    print("Number of concurrent requests: {}".format(args.num_workers))
    print("--------------------------------------")
    print("SUCCESS RATE")
    counter = {}
    for result in results:
        counter[result["status"]] = counter.get(result["status"], 0) + 1
    total_requests = sum(counter.values())
    print("Success rate: {}".format(counter.get(200, 0) / total_requests if total_requests > 0 else -1))
    print("Statuses:", counter)
    print("--------------------------------------")
    print("TIME BY REQUEST")
    successes = [x for x in results if x["success"]]
    if not successes:
        print("No valid uploading time")
        print("Check the results!")
        print("--------------------------------------")
        return
    uploading_time = [x["upload_time"] for x in successes]
    processing_time = [x["process_time"] for x in successes]
    print("Uploading time (Avg / Min / Max): {:.3f}s {:.3f}s {:.3f}s".format(sum(uploading_time) / len(uploading_time), min(uploading_time), max(uploading_time)))
    print("Processing time (Avg / Min / Max): {:.3f}s {:.3f}s {:.3f}s".format(sum(processing_time) / len(processing_time), min(processing_time), max(processing_time)))
    print("--------------------------------------")
    print("TIME BY IMAGE")
    num_images = sum(x["num_files"] for x in successes)
    print("Total images:", num_images)
    if num_images > 0:
        # The label promises upload + processing, so include both components.
        print("Uploading + Processing time: {:.3f}s".format((sum(uploading_time) + sum(processing_time)) / num_images))
    print("--------------------------------------")


if __name__ == "__main__":
    # Guarded so that workers started with the "spawn" method, which
    # re-import this module, do not try to start their own pool.
    # The `with` block closes the pool once all results have been collected.
    with multiprocessing.Pool(processes=args.num_workers) as pool:
        results = list(tqdm.tqdm(
            pool.imap_unordered(process_file, gen_input(num_input=args.num_requests)),
            total=args.num_requests,
        ))
    _print_report(results)