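"""Export low-accuracy OCR fields to CSV and download their source images.

Pulls fwd_api_subscriptionrequestfile rows created between START_DATE and
END_DATE, writes every field whose feedback accuracy falls below
BAD_THRESHOLD to {OUTPUT_NAME}.csv, then downloads the matching request
folders from S3.
"""
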
import csv
import os
from datetime import datetime

import boto3
import psycopg2
from dotenv import load_dotenv
from pytz import timezone
from tqdm import tqdm

load_dotenv("../.env_prod")
# load_dotenv(".env_prod")
# load_dotenv("../.env")

OUTPUT_NAME = "0116-0216"
# pytz timezones must be attached with localize(); passing one through
# tzinfo= silently pins the zone's historical LMT offset instead of +08:00.
SGT = timezone('Asia/Singapore')
START_DATE = SGT.localize(datetime(2024, 1, 16))
END_DATE = SGT.localize(datetime(2024, 2, 16))
BAD_THRESHOLD = 0.75  # export any field with an accuracy score below this

# Output row schema:
# ("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy")

# Positional indices into the `SELECT *` row tuples; they track the column
# order of the underlying tables.
REQUEST_ID_COL = 3           # fwd_api_subscriptionrequest
REQUEST_NUMBER_COL = 6       # fwd_api_subscriptionrequestfile -> parent request id
REQUEST_REDEMPTION_COL = 27  # fwd_api_subscriptionrequest
FILE_NAME_COL = 1            # fwd_api_subscriptionrequestfile
OCR_RESULT_COL = 16
FEEDBACK_RESULT_COL = 15
REVIEWED_RESULT_COL = 17

REVIEW_ACC_COL = 19
FEEDBACK_ACC_COL = 18

# Database connection details (all read from the environment)
db_host = os.environ.get('DB_HOST', "")
db_name = os.environ.get('DB_SCHEMA', "")
db_user = os.environ.get('DB_USER', "")
db_password = os.environ.get('DB_PASSWORD', "")

# S3 bucket details
s3_bucket_name = os.environ.get('S3_BUCKET_NAME', "")
s3_folder_prefix = 'sbt_invoice'

# S3 access credentials
access_key = os.environ.get('S3_ACCESS_KEY', "")
secret_key = os.environ.get('S3_SECRET_KEY', "")
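# Expected keys in the .env file (matching the lookups above):
#   DB_HOST, DB_SCHEMA, DB_USER, DB_PASSWORD,
#   S3_BUCKET_NAME, S3_ACCESS_KEY, S3_SECRET_KEY
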
def get_request(cursor, request_in_id):
    """Fetch the parent fwd_api_subscriptionrequest row by id, or None."""
    query = "SELECT * FROM fwd_api_subscriptionrequest WHERE id = %s"
    cursor.execute(query, (request_in_id,))
    return cursor.fetchone()  # fetchone() already yields None when nothing matches

def main():
    # Connect to the PostgreSQL database
    conn = psycopg2.connect(
        host=db_host,
        database=db_name,
        user=db_user,
        password=db_password
    )
    cursor = conn.cursor()

    # Fetch every file in the date window that carries feedback accuracy
    query = (
        "SELECT * FROM fwd_api_subscriptionrequestfile "
        "WHERE created_at >= %s AND created_at <= %s "
        "AND feedback_accuracy IS NOT NULL"
    )
    cursor.execute(query, (START_DATE, END_DATE))
    data = cursor.fetchall()

    # Rows destined for the CSV:
    # [("requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"), ...]
    bad_image_list = []
    request_ids = []  # for crawling the corresponding images from S3

    # Collect every field whose accuracy falls below BAD_THRESHOLD
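    # feedback_accuracy is assumed to map field name -> list of accuracy
    # scores (one per submitted value); the loop below relies on that shape.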
    for _d in data:
        if _d[FEEDBACK_ACC_COL] and _d[FEEDBACK_RESULT_COL]:
            acc_len = 0  # total number of accuracy scores seen for this file
            for key in _d[FEEDBACK_ACC_COL].keys():
                if key == "purchase_date":
                    continue
                acc_len += len(_d[FEEDBACK_ACC_COL][key])
                if len(_d[FEEDBACK_ACC_COL][key]):
                    if min(_d[FEEDBACK_ACC_COL][key]) < BAD_THRESHOLD:
                        parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                        if parent_request is None:
                            continue
                        requestId = parent_request[REQUEST_ID_COL]
                        redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                        fileName = _d[FILE_NAME_COL]
                        userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                        OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                        revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                        accuracy = _d[FEEDBACK_ACC_COL][key]
                        bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                        request_ids.append(requestId)
            if acc_len == 0:
                # No accuracy scores at all: the request averaged below 0.75
                # and was excluded from the accuracy report, so export every
                # field the user actually corrected.
                for key in _d[FEEDBACK_ACC_COL].keys():
                    if key == "purchase_date":
                        continue
                    # Skip fields where the feedback agrees with the OCR output
                    if _d[OCR_RESULT_COL] and str(_d[FEEDBACK_RESULT_COL][key]) == str(_d[OCR_RESULT_COL][key]):
                        continue
                    parent_request = get_request(cursor, _d[REQUEST_NUMBER_COL])
                    if parent_request is None:
                        continue
                    requestId = parent_request[REQUEST_ID_COL]
                    redemptionNumber = parent_request[REQUEST_REDEMPTION_COL]
                    fileName = _d[FILE_NAME_COL]
                    userSubmitResults = str(_d[FEEDBACK_RESULT_COL][key]) if _d[FEEDBACK_RESULT_COL] else ""
                    OCRResults = str(_d[OCR_RESULT_COL][key]) if _d[OCR_RESULT_COL] else ""
                    revisedResults_by_SDSRV = str(_d[REVIEWED_RESULT_COL][key]) if _d[REVIEWED_RESULT_COL] else ""
                    accuracy = "Unknown (avg request acc < 0.75 is excluded from the acc report)"
                    bad_image_list.append((requestId, redemptionNumber, fileName, userSubmitResults, OCRResults, revisedResults_by_SDSRV, accuracy))
                    request_ids.append(requestId)

    # ###################### Write bad requests to CSV ######################
    csv_file_path = f'{OUTPUT_NAME}.csv'
    with open(csv_file_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        # Column headers, then the filtered rows
        writer.writerow(["requestId", "redemptionNumber", "fileName", "userSubmitResults", "OCRResults", "revisedResults_by_SDSRV", "accuracy"])
        writer.writerows(bad_image_list)

    # Close the cursor and database connection
    cursor.close()
    conn.close()

    # ###################### Download request folders from S3 ######################
    s3_client = boto3.client(
        's3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    request_ids = list(set(request_ids))  # de-duplicate before downloading

    for request_id in tqdm(request_ids):
        # Assumed layout: s3://{s3_bucket_name}/{s3_folder_prefix}/{request_id}/...
        folder_key = f"{s3_folder_prefix}/{request_id}/"
        local_folder_path = f"{OUTPUT_NAME}/{request_id}/"
        os.makedirs(local_folder_path, exist_ok=True)  # also creates OUTPUT_NAME/

        # List objects in the S3 folder; list_objects_v2 returns at most 1000
        # keys per call, assumed sufficient for a single request's folder
        response = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=folder_key)
        objects = response.get('Contents', [])

        for s3_object in objects:
            object_key = s3_object['Key']
            # The file name is the last component of the object key
            local_file_path = local_folder_path + object_key.split('/')[-1]
            # Download the S3 object to the local file
            s3_client.download_file(s3_bucket_name, object_key, local_file_path)


if __name__ == "__main__":
    main()