import csv
import os
from datetime import datetime

import boto3
import psycopg2
from dotenv import load_dotenv
from pytz import timezone
from tqdm import tqdm

load_dotenv("../.env_prod")
# load_dotenv(".env_prod")
# load_dotenv("../.env")

OUTPUT_NAME = "0116-0216"
# pytz timezones must be attached with localize(); passing them as tzinfo= yields the
# historical LMT offset instead of +08:00.
START_DATE = timezone('Asia/Singapore').localize(datetime(2024, 1, 16))
END_DATE = timezone('Asia/Singapore').localize(datetime(2024, 2, 16))
BAD_THRESHOLD = 0.75

# Output CSV row layout:
# ("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy")

# Column indexes into the fwd_api_subscriptionrequest / fwd_api_subscriptionrequestfile rows
REQUEST_ID_COL = 3
REQUEST_NUMBER_COL = 6
REQUEST_REDEMPTION_COL = 27
FILE_NAME_COL = 1
OCR_RESULT_COL = 16
FEEDBACK_RESULT_COL = 15
REVIEWED_RESULT_COL = 17
REVIEW_ACC_COL = 19
FEEDBACK_ACC_COL = 18

# Database connection details
db_host = os.environ.get('DB_HOST', "")
db_name = os.environ.get('DB_SCHEMA', "")
db_user = os.environ.get('DB_USER', "")
db_password = os.environ.get('DB_PASSWORD', "")
# db_host = "sbt.cxetpslawu4p.ap-southeast-1.rds.amazonaws.com"
# db_name = "sbt2"
# db_user = "sbt"
# db_password = "sbtCH240"

# S3 bucket details
s3_bucket_name = os.environ.get('S3_BUCKET_NAME', "")
s3_folder_prefix = 'sbt_invoice'

# S3 access credentials
access_key = os.environ.get('S3_ACCESS_KEY', "")
secret_key = os.environ.get('S3_SECRET_KEY', "")


def get_request(cursor, request_in_id):
    """Fetch the parent subscription request row by its primary key."""
    query = "SELECT * FROM fwd_api_subscriptionrequest WHERE id = %s"
    cursor.execute(query, (request_in_id,))
    data = cursor.fetchone()
    return data if data else None


def main():
    # Connect to the PostgreSQL database
    conn = psycopg2.connect(
        host=db_host,
        database=db_name,
        user=db_user,
        password=db_password
    )

    # Create a cursor
    cursor = conn.cursor()

    # Select all request files in the date range that have feedback accuracy recorded
    query = (
        "SELECT * FROM fwd_api_subscriptionrequestfile "
        "WHERE created_at >= %s AND created_at <= %s AND feedback_accuracy IS NOT NULL"
    )
    cursor.execute(query, (START_DATE, END_DATE))

    # Fetch the filtered data
    data = cursor.fetchall()

    # [("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"), ...]
    bad_image_list = []
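    # Flagging logic (see the loop below): feedback_accuracy maps each field name to a
    # list of per-field accuracy scores. A file is flagged as "bad" when any field other
    # than purchase_date has a minimum score below BAD_THRESHOLD; files with no accuracy
    # entries at all (requests excluded from the accuracy report) are flagged for every
    # field where the user-submitted value disagrees with the OCR result.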
    request_ids = []  # parent request IDs, used below to crawl images from S3

    # Filter out request files whose field-level accuracy is below 75%
    for i, _d in enumerate(data):
        if _d[FEEDBACK_ACC_COL] and _d[FEEDBACK_RESULT_COL]:
            acc_len = 0
            for key in _d[FEEDBACK_ACC_COL].keys():
                if key == "purchase_date":
                    continue
                acc_len += len(_d[FEEDBACK_ACC_COL][key])
                if len(_d[FEEDBACK_ACC_COL][key]):
                    if min(_d[FEEDBACK_ACC_COL][key]) < BAD_THRESHOLD:
                        parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                        requestId = parent_request[REQUEST_ID_COL]
                        redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                        fileName = _d[FILE_NAME_COL]
                        userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                        OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                        revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                        accuracy = _d[FEEDBACK_ACC_COL][key]
                        bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults,
                                               OCRResults, revisedResults_by_SDSRV, accuracy))
                        request_ids.append(requestId)

            if acc_len == 0:
                # No accuracy entries at all: this request's average accuracy was < 0.75,
                # so it is excluded from the accuracy report. Flag every field where the
                # user-submitted value disagrees with the OCR result.
                for key in _d[FEEDBACK_ACC_COL].keys():
                    if key == "purchase_date":
                        continue
                    if _d[OCR_RESULT_COL] and str(_d[FEEDBACK_RESULT_COL][key]) == str(_d[OCR_RESULT_COL][key]):
                        continue
                    parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                    requestId = parent_request[REQUEST_ID_COL]
                    redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                    fileName = _d[FILE_NAME_COL]
                    userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                    OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                    revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                    accuracy = "Unknown (avg request acc < 0.75 is excluded from the acc report)"
                    bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults,
                                           OCRResults, revisedResults_by_SDSRV, accuracy))
                    request_ids.append(requestId)

    # ###################### Get bad requests ######################
    # for bad_image in bad_images:
    #     request = get_request(cursor, bad_image.request_id)
    #     if request:
    #         request_ids.append(request[3])

    # Write the bad-image rows to the CSV file
    csv_file_path = f'{OUTPUT_NAME}.csv'
    with open(csv_file_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        # Column headers
        writer.writerow(["requestId", "redemptionNumber", "fileName", "userSubmitResults",
                         "OCRResults", "revisedResults_by_SDSRV", "accuracy"])
        # Filtered data rows
        writer.writerows(bad_image_list)

    # Close the cursor and database connection
    cursor.close()
    conn.close()

    # Download the matching image folders from S3
    s3_client = boto3.client(
        's3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    request_ids = list(set(request_ids))  # de-duplicate before downloading
    for request_id in tqdm(request_ids):
        # Assuming a folder structure like: s3_bucket_name/s3_folder_prefix/request_id/
        folder_key = f"{s3_folder_prefix}/{request_id}/"
        # Local folder to save the downloaded files into
        local_folder_path = f"{OUTPUT_NAME}/{request_id}/"
        os.makedirs(local_folder_path, exist_ok=True)

        # List objects in the S3 folder
        response = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=folder_key)
        objects = response.get('Contents', [])
        for s3_object in objects:
            object_key = s3_object['Key']
            if object_key.endswith('/'):
                continue  # skip the folder placeholder object, if present
            # Extract the file name from the object key
            local_file_path = local_folder_path + object_key.split('/')[-1]

            # Download the S3 object to the local file
            s3_client.download_file(s3_bucket_name, object_key, local_file_path)


if __name__ == "__main__":
    main()