sbt-idp/scripts/crawl_database_by_time_with_accuracy_contrain.py
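
"""
Crawl fwd_api_subscriptionrequestfile rows created between START_DATE and END_DATE,
collect the fields whose feedback accuracy falls below BAD_THRESHOLD (or whose request
was excluded from the accuracy report), write them to {OUTPUT_NAME}.csv, and download
the corresponding request images from S3 into a local {OUTPUT_NAME}/ folder.
"""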


import csv
from typing import Any
import psycopg2
import boto3
import os
from tqdm import tqdm
from datetime import datetime, timedelta
from pytz import timezone
from dotenv import load_dotenv
load_dotenv("../.env_prod")
# load_dotenv(".env_prod")
# load_dotenv("../.env")
OUTPUT_NAME = "0116-0216"
# Use localize() rather than passing the pytz timezone as tzinfo directly; the latter
# attaches Singapore's historical LMT offset (+06:55) instead of the intended +08:00.
START_DATE = timezone('Asia/Singapore').localize(datetime(2024, 1, 16))
END_DATE = timezone('Asia/Singapore').localize(datetime(2024, 2, 16))
BAD_THRESHOLD = 0.75
# ("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy")
REQUEST_ID_COL = 3
REQUEST_NUMBER_COL = 6
REQUEST_REDEMPTION_COL = 27
FILE_NAME_COL = 1
OCR_RESULT_COL = 16
FEEDBACK_RESULT_COL = 15
REVIEWED_RESULT_COL = 17
REVIEW_ACC_COL = 19
FEEDBACK_ACC_COL = 18
# Database connection details
db_host = os.environ.get('DB_HOST', "")
db_name = os.environ.get('DB_SCHEMA', "")
db_user = os.environ.get('DB_USER', "")
db_password = os.environ.get('DB_PASSWORD', "")
# db_host = "sbt.cxetpslawu4p.ap-southeast-1.rds.amazonaws.com"
# db_name = "sbt2"
# db_user = "sbt"
# db_password = "sbtCH240"
# S3 bucket details
s3_bucket_name = os.environ.get('S3_BUCKET_NAME', "")
s3_folder_prefix = 'sbt_invoice'
# S3 access credentials
access_key = os.environ.get('S3_ACCESS_KEY', "")
secret_key = os.environ.get('S3_SECRET_KEY', "")
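
# A minimal sketch of the variables this script expects in ../.env_prod (names taken
# from the os.environ lookups above; the values shown are placeholders, not real ones):
#   DB_HOST=<postgres-host>
#   DB_SCHEMA=<database-name>
#   DB_USER=<database-user>
#   DB_PASSWORD=<database-password>
#   S3_BUCKET_NAME=<bucket-name>
#   S3_ACCESS_KEY=<aws-access-key-id>
#   S3_SECRET_KEY=<aws-secret-access-key>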
def get_request(cursor, request_in_id):
    query = "SELECT * FROM fwd_api_subscriptionrequest WHERE id = %s"
    cursor.execute(query, (request_in_id,))
    data = cursor.fetchone()
    return data if data else None
# Request IDs for filtering
def main():
    # Connect to the PostgreSQL database
    conn = psycopg2.connect(
        host=db_host,
        database=db_name,
        user=db_user,
        password=db_password
    )
    # Create a cursor
    cursor = conn.cursor()

    # Execute the SELECT query with the filter
    query = "SELECT * FROM fwd_api_subscriptionrequestfile WHERE created_at >= %s AND created_at <= %s AND feedback_accuracy IS NOT NULL"
    cursor.execute(query, (START_DATE, END_DATE))

    # Fetch the filtered data
    data = cursor.fetchall()

    # Define the CSV file path
    csv_file_path = f'{OUTPUT_NAME}.csv'

    bad_image_list = []  # [("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"), ...]
    request_ids = []  # for crawling images

    # Collect files from requests whose field accuracy is below the 75% threshold
    for i, _d in enumerate(data):
        if _d[FEEDBACK_ACC_COL] and _d[FEEDBACK_RESULT_COL]:
            acc_len = 0
            for key in _d[FEEDBACK_ACC_COL].keys():
                if key == "purchase_date":
                    continue
                acc_len += len(_d[FEEDBACK_ACC_COL][key])
                if len(_d[FEEDBACK_ACC_COL][key]):
                    if min(_d[FEEDBACK_ACC_COL][key]) < BAD_THRESHOLD:
                        parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                        requestId = parent_request[REQUEST_ID_COL]
                        redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                        fileName = _d[FILE_NAME_COL]
                        userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                        OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                        revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                        accuracy = _d[FEEDBACK_ACC_COL][key]
                        bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                        request_ids.append(requestId)
            if acc_len == 0:  # No per-field accuracy at all: this request's avg acc < 0.75 and it was excluded from the accuracy report
                for key in _d[FEEDBACK_ACC_COL].keys():
                    if key == "purchase_date":
                        continue
                    # Skip fields where the user-submitted value matches the OCR value
                    if _d[OCR_RESULT_COL] and str(_d[FEEDBACK_RESULT_COL][key]) == str(_d[OCR_RESULT_COL][key]):
                        continue
                    parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                    requestId = parent_request[REQUEST_ID_COL]
                    redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                    fileName = _d[FILE_NAME_COL]
                    userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                    OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                    revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                    accuracy = "Unknown (avg request acc < 0.75 is excluded from the acc report)"
                    bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                    request_ids.append(requestId)
    # Write the data to the CSV file
    # for bad_image in bad_images:
    #     request = get_request(cursor, bad_image.request_id)
    #     if request:
    #         request_ids.append(request[3])

    # ###################### Get bad requests ######################
    # Define the CSV file path
    csv_file_path = f'{OUTPUT_NAME}.csv'

    # Write the data to the CSV file
    with open(csv_file_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"])  # Write column headers
        writer.writerows(bad_image_list)  # Write the filtered data rows

    # Close the cursor and database connection
    cursor.close()
    conn.close()
    # Download folders from S3
    s3_client = boto3.client(
        's3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )
    request_ids = list(set(request_ids))  # de-duplicate before downloading
    for request_id in tqdm(request_ids):
        folder_key = f"{s3_folder_prefix}/{request_id}/"  # Assuming folder structure like: s3_bucket_name/s3_folder_prefix/request_id/
        local_folder_path = f"{OUTPUT_NAME}/{request_id}/"  # Path to the local folder to save the downloaded files
        os.makedirs(OUTPUT_NAME, exist_ok=True)
        os.makedirs(local_folder_path, exist_ok=True)

        # List objects in the S3 folder (list_objects_v2 returns at most 1,000 keys per call)
        response = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=folder_key)
        objects = response.get('Contents', [])
        for s3_object in objects:
            object_key = s3_object['Key']
            if object_key.endswith('/'):  # Skip zero-byte "folder" marker objects
                continue
            local_file_path = local_folder_path + object_key.split('/')[-1]  # Extracting the file name from the object key
            # Download the S3 object to the local file
            s3_client.download_file(s3_bucket_name, object_key, local_file_path)
if __name__ == "__main__":
    main()
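
# Note (not part of the original script): list_objects_v2 returns at most 1,000 keys
# per call, so request folders larger than that would be truncated. A paginator-based
# variant, sketched here with boto3's standard paginator API, would cover that case:
#
#   paginator = s3_client.get_paginator('list_objects_v2')
#   for page in paginator.paginate(Bucket=s3_bucket_name, Prefix=folder_key):
#       for s3_object in page.get('Contents', []):
#           ...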