# sbt-idp/scripts/crawl_database_by_time_with_accuracy_contrain.py
# Repository web-view metadata: Python, 171 lines, 5.6 KiB; 2024-02-07 05:39:24 +00:00.
import csv
from typing import Any
import psycopg2
import boto3
import os
from tqdm import tqdm
from datetime import datetime, timedelta
from pytz import timezone
from dotenv import load_dotenv
# Load DB / S3 settings from the production env file (path relative to the working directory).
# Use "../.env" instead to run against the non-production environment.
load_dotenv("../.env_prod")

# Base name for the CSV report ("<OUTPUT_NAME>.csv") and the download folder ("<OUTPUT_NAME>/").
OUTPUT_NAME = "0131-0206"

# Crawl window in Singapore local time.
# pytz zones must be attached with ``localize``: passing one as ``tzinfo=`` to the datetime
# constructor uses the zone's first historical (LMT, +06:55) offset instead of +08:00, which
# silently shifts the query window by about 65 minutes.
# NOTE(review): END_DATE is midnight at the *start* of 2024-02-06 and the query uses
# ``created_at <= END_DATE``, so files created later on 02-06 are excluded — confirm intent.
_LOCAL_TZ = timezone('Asia/Singapore')
START_DATE = _LOCAL_TZ.localize(datetime(2024, 1, 31))
END_DATE = _LOCAL_TZ.localize(datetime(2024, 2, 6))

# A request whose accuracy is strictly below this value is reported as bad.
BAD_THRESHOLD = 0.75

# Positional column indices into ``SELECT * FROM fwd_api_subscriptionrequestfile`` rows.
# NOTE(review): these depend on the table's physical column order — re-check after migrations.
REVIEW_ACC_COL = 19
FEEDBACK_ACC_COL = 18
REQUEST_ID_COL = 6

# Database connection details (empty string when the variable is unset).
db_host = os.environ.get('DB_HOST', "")
db_name = os.environ.get('DB_SCHEMA', "")
db_user = os.environ.get('DB_USER', "")
db_password = os.environ.get('DB_PASSWORD', "")

# S3 bucket details.
s3_bucket_name = os.environ.get('S3_BUCKET_NAME', "")
s3_folder_prefix = 'sbt_invoice'

# S3 access credentials.
access_key = os.environ.get('S3_ACCESS_KEY', "")
secret_key = os.environ.get('S3_SECRET_KEY', "")
class RequestAtt:
    """Accumulate the accuracy scores of all files that belong to one request.

    ``add_file`` collects per-file scores; ``is_bad_image`` then decides whether the
    request's accuracy falls below ``BAD_THRESHOLD``.
    """

    def __init__(self) -> None:
        # Flat lists of scores gathered from every file added via ``add_file``.
        self.feedback_accuracy = []
        # NOTE: the misspelled attribute name is kept so existing references keep working.
        self.reiviewed_accuracy = []
        # Accuracy value the last ``is_bad_image`` call based its decision on.
        self.acc = 0
        self.request_id = None
        self.is_bad = False
        # Raw file rows, in the order they were added.
        self.data = []

    def add_file(self, file):
        """Record one file row and collect its reviewed and feedback accuracy scores.

        Args:
            file: an indexable row; the entries at ``REVIEW_ACC_COL`` and
                ``FEEDBACK_ACC_COL`` are mappings (or empty/None) whose values are
                iterables of scores — presumably lists of floats, confirm against the schema.
        """
        self.data.append(file)
        reviewed = file[REVIEW_ACC_COL]
        if reviewed:
            for scores in reviewed.values():
                self.reiviewed_accuracy += scores
        feedback = file[FEEDBACK_ACC_COL]
        if feedback:
            for scores in feedback.values():
                self.feedback_accuracy += scores

    @staticmethod
    def _mean(scores):
        """Return the arithmetic mean of *scores*, or None when *scores* is empty."""
        return sum(scores) / len(scores) if scores else None

    def is_bad_image(self):
        """Decide whether this request's accuracy is below ``BAD_THRESHOLD``.

        The mean reviewed accuracy takes precedence; the mean feedback accuracy is used
        when no reviewed scores exist. Updates ``self.acc`` (only when a score exists)
        and ``self.is_bad``.

        Returns:
            bool: True if the chosen accuracy is below the threshold, False otherwise
            (including when no scores were collected).
        """
        feedback = self._mean(self.feedback_accuracy)
        reviewed = self._mean(self.reiviewed_accuracy)
        # ``is not None`` so that a genuine accuracy of 0.0 is not mistaken for "no data".
        acc = reviewed if reviewed is not None else feedback
        if acc is None:
            self.is_bad = False
            return False
        self.acc = acc
        self.is_bad = acc < BAD_THRESHOLD
        return self.is_bad
def get_request(cursor, request_in_id):
    """Look up the ``fwd_api_subscriptionrequest`` row whose ``id`` equals *request_in_id*.

    Args:
        cursor: an open DB-API cursor.
        request_in_id: value compared against the ``id`` column.

    Returns:
        The fetched row, or None when no row matches.
    """
    cursor.execute(
        "SELECT * FROM fwd_api_subscriptionrequest WHERE id = %s",
        (request_in_id,),
    )
    row = cursor.fetchone()
    return row or None
def _collect_requests(cursor):
    """Group the files created in [START_DATE, END_DATE] that have feedback, by request.

    Returns:
        dict mapping each row's ``REQUEST_ID_COL`` value to its ``RequestAtt``,
        in first-seen order.
    """
    cursor.execute(
        "SELECT * FROM fwd_api_subscriptionrequestfile "
        "WHERE created_at >= %s AND created_at <= %s AND feedback_accuracy IS NOT NULL",
        (START_DATE, END_DATE),
    )
    grouped = {}
    for row in cursor.fetchall():
        key = row[REQUEST_ID_COL]
        request = grouped.get(key)
        if request is None:
            request = grouped[key] = RequestAtt()
            request.request_id = key
        request.add_file(row)
    return grouped


def _resolve_request_ids(cursor, bad_requests):
    """Return the ``request_id`` value of the subscription request behind each bad request."""
    request_ids = []
    for bad in bad_requests:
        request = get_request(cursor, bad.request_id)
        if request:
            # NOTE(review): column 3 is assumed to hold the ``request_id`` that is matched in
            # the CSV query and used as the S3 folder name — confirm against the schema.
            request_ids.append(request[3])
    return request_ids


def _write_requests_csv(cursor, request_ids, csv_file_path):
    """Write the ``fwd_api_subscriptionrequest`` rows for *request_ids*, with a header, to CSV."""
    # ``= ANY(%s)`` passes the whole list as one array parameter; unlike a generated
    # ``IN (...)`` it remains valid SQL when *request_ids* is empty.
    cursor.execute(
        "SELECT * FROM fwd_api_subscriptionrequest WHERE request_id = ANY(%s)",
        (request_ids,),
    )
    rows = cursor.fetchall()
    with open(csv_file_path, 'w', newline='', encoding='utf-8') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow([desc[0] for desc in cursor.description])  # column headers
        writer.writerows(rows)


def _download_request_folders(request_ids):
    """Download every object under ``<s3_folder_prefix>/<request_id>/`` into ``OUTPUT_NAME/<request_id>/``.

    Objects are stored flat, under the last component of their key.
    """
    s3_client = boto3.client(
        's3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    # A single list_objects_v2 call returns at most 1000 keys; the paginator follows
    # continuation tokens so large folders are fetched completely.
    paginator = s3_client.get_paginator('list_objects_v2')
    for request_id in tqdm(request_ids):
        folder_key = f"{s3_folder_prefix}/{request_id}/"
        local_folder_path = os.path.join(OUTPUT_NAME, str(request_id))
        os.makedirs(local_folder_path, exist_ok=True)
        for page in paginator.paginate(Bucket=s3_bucket_name, Prefix=folder_key):
            for s3_object in page.get('Contents', []):
                object_key = s3_object['Key']
                file_name = object_key.rsplit('/', 1)[-1]
                if not file_name:
                    # Keys ending in '/' are folder placeholders with no file to save.
                    continue
                s3_client.download_file(
                    s3_bucket_name, object_key, os.path.join(local_folder_path, file_name)
                )


def main():
    """Export low-accuracy requests from the crawl window and download their S3 files.

    Groups request files by request and keeps requests whose accuracy is below
    ``BAD_THRESHOLD``. Writes their ``fwd_api_subscriptionrequest`` rows to
    ``<OUTPUT_NAME>.csv``, then downloads their S3 folders into ``OUTPUT_NAME/``.
    """
    conn = psycopg2.connect(
        host=db_host,
        database=db_name,
        user=db_user,
        password=db_password,
    )
    try:
        with conn.cursor() as cursor:
            grouped = _collect_requests(cursor)
            # Keep only requests whose accuracy is below the threshold.
            bad_requests = [req for req in grouped.values() if req.is_bad_image()]
            request_ids = _resolve_request_ids(cursor, bad_requests)
            _write_requests_csv(cursor, request_ids, f'{OUTPUT_NAME}.csv')
    finally:
        # Release the DB connection before the slow S3 phase, and also on errors.
        conn.close()
    _download_request_folders(request_ids)


if __name__ == "__main__":
    main()