import csv
import os
from datetime import datetime

import boto3
import psycopg2
from dotenv import load_dotenv
from pytz import timezone
from tqdm import tqdm

load_dotenv("../.env_prod")
# load_dotenv(".env_prod")
# load_dotenv("../.env")

OUTPUT_NAME = "0116-0216"
# pytz timezones must be attached with localize(); passing them as tzinfo= to the
# datetime constructor silently uses the zone's historical LMT offset instead of +08:00.
START_DATE = timezone('Asia/Singapore').localize(datetime(2024, 1, 16))
END_DATE = timezone('Asia/Singapore').localize(datetime(2024, 2, 16))
BAD_THRESHOLD = 0.75
# ("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy")
REQUEST_ID_COL = 3
REQUEST_NUMBER_COL = 6
REQUEST_REDEMPTION_COL = 27
FILE_NAME_COL = 1
OCR_RESULT_COL = 16
FEEDBACK_RESULT_COL = 15
REVIEWED_RESULT_COL = 17

REVIEW_ACC_COL = 19  # currently unused in this script
FEEDBACK_ACC_COL = 18

# Database connection details
db_host = os.environ.get('DB_HOST', "")
db_name = os.environ.get('DB_SCHEMA', "")
db_user = os.environ.get('DB_USER', "")
db_password = os.environ.get('DB_PASSWORD', "")
# db_host = "sbt.cxetpslawu4p.ap-southeast-1.rds.amazonaws.com"
# db_name = "sbt2"
# db_user = "sbt"
# db_password = "sbtCH240"

# S3 bucket details
s3_bucket_name = os.environ.get('S3_BUCKET_NAME', "")
s3_folder_prefix = 'sbt_invoice'
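# Per-request images are expected under s3://<bucket>/sbt_invoice/<request_id>/ (see the download loop below)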

# S3 access credentials
access_key = os.environ.get('S3_ACCESS_KEY', "")
secret_key = os.environ.get('S3_SECRET_KEY', "")

def get_request(cursor, request_in_id):
    """Fetch the parent fwd_api_subscriptionrequest row for the given id, or None if absent."""
    query = "SELECT * FROM fwd_api_subscriptionrequest WHERE id = %s"
    cursor.execute(query, (request_in_id,))
    return cursor.fetchone()  # fetchone() already returns None when no row matches

def main():
    # Connect to the PostgreSQL database
    conn = psycopg2.connect(
        host=db_host,
        database=db_name,
        user=db_user,
        password=db_password
    )

    # Create a cursor
    cursor = conn.cursor()


    # Select request files created in the reporting window that have feedback accuracy recorded
    query = "SELECT * FROM fwd_api_subscriptionrequestfile WHERE created_at >= %s AND created_at <= %s AND feedback_accuracy IS NOT NULL"
    cursor.execute(query, (START_DATE, END_DATE))

    # Fetch the filtered data
    data = cursor.fetchall()

    bad_image_list = [] # [("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"), ...]
    request_ids = [] # for crawling images
    # Collect files where any field's feedback accuracy falls below BAD_THRESHOLD
    for _d in data:
        if _d[FEEDBACK_ACC_COL] and _d[FEEDBACK_RESULT_COL]:
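            # feedback_accuracy / feedback_result are expected to be JSON dicts keyed by field
            # name; each accuracy entry is assumed to be a list of per-line scores in [0, 1].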
            acc_len = 0
            for key in _d[FEEDBACK_ACC_COL].keys():
                if key == "purchase_date":
                    continue
                acc_len += len(_d[FEEDBACK_ACC_COL][key])
                if len(_d[FEEDBACK_ACC_COL][key]):
                    if min(_d[FEEDBACK_ACC_COL][key]) < BAD_THRESHOLD:
                        parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                        requestId = parent_request[REQUEST_ID_COL]
                        redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                        fileName = _d[FILE_NAME_COL]
                        userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                        OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                        revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                        accuracy = _d[FEEDBACK_ACC_COL][key]
                        bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                        request_ids.append(requestId)
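            # acc_len == 0 means no field had any accuracy recorded; in that case fall back to
            # flagging fields where the user's submission differs from the OCR output.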
            if acc_len == 0:  # the request was excluded from the accuracy report (avg acc < 0.75)
                for key in _d[FEEDBACK_ACC_COL].keys():
                    if key == "purchase_date":
                        continue
                    # Skip fields where the user-submitted value matches the OCR output (no correction was made)
                    ocr_value = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                    if str(_d[FEEDBACK_RESULT_COL][key]) == ocr_value:
                        continue
                    parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                    requestId = parent_request[REQUEST_ID_COL]
                    redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                    fileName = _d[FILE_NAME_COL]
                    userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                    OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                    revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                    accuracy = "Unknown (avg request acc < 0.75 is excluded from the acc report)"
                    bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                    request_ids.append(requestId)

    # ###################### Write the bad-image report to CSV ######################
    # Define the CSV file path
    csv_file_path = f'{OUTPUT_NAME}.csv'

    # Write the data to the CSV file
    with open(csv_file_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"])  # Write column headers
        writer.writerows(bad_image_list)  # Write the filtered data rows

    # Close the cursor and database connection
    cursor.close()
    conn.close()

    # Download folders from S3
    s3_client = boto3.client(
        's3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    request_ids = list(set(request_ids))  # de-duplicate so each request folder is downloaded only once

    for request_id in tqdm(request_ids):
        folder_key = f"{s3_folder_prefix}/{request_id}/"  # Assuming folder structure like: s3_bucket_name/s3_folder_prefix/request_id/
        local_folder_path = f"{OUTPUT_NAME}/{request_id}/"  # Path to the local folder to save the downloaded files
        os.makedirs(local_folder_path, exist_ok=True)  # also creates the OUTPUT_NAME parent directory

        # List objects in the S3 folder
        response = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=folder_key)
        objects = response.get('Contents', [])
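        # Note: list_objects_v2 returns at most 1,000 keys per request; if a request folder
        # could contain more objects than that, pagination would be needed here.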
        
        for s3_object in objects:
            object_key = s3_object['Key']
            if object_key.endswith('/'):
                continue  # skip zero-byte "folder" placeholder keys
            local_file_path = local_folder_path + object_key.split('/')[-1]  # extract the file name from the object key

            # Download the S3 object to the local file
            s3_client.download_file(s3_bucket_name, object_key, local_file_path)

if __name__ == "__main__":
    main()