From 52ba638eb7615851e86f5892069d55d125b4924c Mon Sep 17 00:00:00 2001 From: dx-tan Date: Fri, 8 Dec 2023 19:49:00 +0700 Subject: [PATCH] Add feedback API, parallel proccessing --- cope2n-api/fwd_api/api/ctel_view.py | 66 ++++++++++++++++++- .../fwd_api/celery_worker/internal_task.py | 25 +++++-- .../celery_worker/process_result_tasks.py | 44 ++++++++++++- .../fwd_api/models/SubscriptionRequest.py | 2 + cope2n-api/fwd_api/utils/FileUtils.py | 14 +++- cope2n-api/fwd_api/utils/ProcessUtil.py | 29 ++++++++ 6 files changed, 169 insertions(+), 11 deletions(-) diff --git a/cope2n-api/fwd_api/api/ctel_view.py b/cope2n-api/fwd_api/api/ctel_view.py index d253c7a..7e38ad1 100755 --- a/cope2n-api/fwd_api/api/ctel_view.py +++ b/cope2n-api/fwd_api/api/ctel_view.py @@ -67,6 +67,8 @@ class CtelViewSet(viewsets.ViewSet): total_page = 1 new_request: SubscriptionRequest = SubscriptionRequest(pages=total_page, + pages_left=total_page, + doc_type="all", process_type=p_type, status=1, request_id=rq_id, provider_code=provider_code, subscription=sub) @@ -91,7 +93,7 @@ class CtelViewSet(viewsets.ViewSet): print(f"[INFO]: Duration of Pre-processing: {j_time - s_time}s") print(f"[INFO]: b_url: {b_url}") if p_type in standard_ocr_list: - ProcessUtil.send_to_queue2(rq_id, sub.id, b_url, user.id, p_type) + ProcessUtil.send_to_queue2(rq_id + "_sub_0", sub.id, b_url, user.id, p_type) if p_type == ProcessType.TEMPLATE_MATCHING.value: ProcessUtil.send_template_queue(rq_id, b_url, validated_data['template'], user.id) else: @@ -149,6 +151,7 @@ class CtelViewSet(viewsets.ViewSet): list_urls = [] p_type = validated_data['type'] new_request: SubscriptionRequest = SubscriptionRequest(pages=total_page, + pages_left=total_page, process_type=p_type, status=1, request_id=rq_id, provider_code=provider_code, subscription=sub) @@ -226,6 +229,7 @@ class CtelViewSet(viewsets.ViewSet): list_urls = [] p_type = validated_data['type'] new_request: SubscriptionRequest = SubscriptionRequest(pages=total_page, + pages_left=total_page, process_type=p_type, status=1, request_id=rq_id, provider_code=provider_code, subscription=sub) @@ -286,6 +290,62 @@ class CtelViewSet(viewsets.ViewSet): return Response(status=status.HTTP_200_OK, data=serializer.data[0]) raise ServiceTimeoutException(excArgs=f"{rq_id}") + @extend_schema(request={ + 'multipart/form-data': { + 'type': 'object', + 'properties': { + 'request_id': { + 'type': 'string', + }, + 'retailername': { + 'type': 'string', + }, + 'sold_to_party': { + 'type': 'string', + }, + 'purchase_date': { + 'type': 'array', + 'items': { + 'type': 'string', + } + }, + 'imei_number': { + 'type': 'array', + 'items': { + 'type': 'string', + } + }, + }, + 'required': ['request_id', 'retailername', 'sold_to_party', 'purchase_date', 'imei_number'] + } + }, responses=None, tags=['ocr']) + @action(detail=False, url_path="images/feedback", methods=["POST"]) + # @transaction.atomic + def feedback(self, request): + # s_time = time.time() + # user_info = ProcessUtil.get_user(request) + # user = user_info.user + # sub = user_info.current_sub + + validated_data = ProcessUtil.sbt_validate_feedback(request) + rq_id = validated_data['request_id'] + + subcription_request = SubscriptionRequest.objects.filter(request_id=rq_id) + if len(subcription_request) == 0: + raise InvalidException(excArgs=f"{rq_id}") + subcription_request = subcription_request[0] + + # Save to database + subcription_request.feedback_result = validated_data + subcription_request.save() + file_name = f"feedback_{rq_id}.json" + # Save to local + file_path = FileUtils.save_json_file(file_name, subcription_request, validated_data) + # Upload to S3 + S3_path = FileUtils.save_to_S3(file_name, subcription_request, file_path) + + return JsonResponse(status=status.HTTP_200_OK, data={"request_id": rq_id}) + @extend_schema(request=None, responses=None, tags=['data']) @extend_schema(request=None, responses=None, tags=['templates'], methods=['GET']) @@ -400,6 +460,10 @@ class CtelViewSet(viewsets.ViewSet): # print(f"[DEBUG]: result: {serializer.data[0]}") if report_filter[0].status == 400: raise FileContentInvalidException() + if report_filter[0].status == 100: # continue, only return when result is fullfilled + empty_data = serializer.data[0] + empty_data["data"] = None + return Response(status=status.HTTP_200_OK, data=empty_data) return Response(status=status.HTTP_200_OK, data=serializer.data[0]) diff --git a/cope2n-api/fwd_api/celery_worker/internal_task.py b/cope2n-api/fwd_api/celery_worker/internal_task.py index d79bca8..9f311c3 100755 --- a/cope2n-api/fwd_api/celery_worker/internal_task.py +++ b/cope2n-api/fwd_api/celery_worker/internal_task.py @@ -81,7 +81,7 @@ def process_image_file(file_name: str, file_path, request, user) -> list: @app.task(name='do_pdf') def process_pdf(rq_id, sub_id, p_type, user_id, files): """ - pdf_files: [{ + files: [{ "file_name": "", "file_path": "", # local path to file "file_type": "" @@ -93,11 +93,16 @@ def process_pdf(rq_id, sub_id, p_type, user_id, files): new_request = SubscriptionRequest.objects.filter(request_id=rq_id)[0] user = UserProfile.objects.filter(id=user_id).first() b_urls = [] + new_request.pages = len(files) + new_request.pages_left = len(files) + for i, file in enumerate(files): extension = file["file_name"].split(".")[-1].lower() if extension == "pdf": _b_urls = process_pdf_file(file["file_name"], file["file_path"], new_request, user) if _b_urls is None: + new_request.status = 400 + new_request.save() raise FileContentInvalidException for j in range(len(_b_urls)): _b_urls[j]["doc_type"] = file["file_type"] @@ -113,10 +118,20 @@ def process_pdf(rq_id, sub_id, p_type, user_id, files): start_process = time.time() logger.info(f"BE proccessing time: {start_process - start}") - if p_type in standard_ocr_list: - ProcessUtil.send_to_queue2(rq_id, sub_id, b_urls, user_id, p_type) - if p_type == ProcessType.TEMPLATE_MATCHING.value: - ProcessUtil.send_template_queue(rq_id, b_urls, '', user_id) + # TODO: send to queue with different request_ids + doc_type_string ="" + for i, b_url in enumerate(b_urls): + fractorized_request_id = rq_id + f"_sub_{i}" + ProcessUtil.send_to_queue2(fractorized_request_id, sub_id, [b_url], user_id, p_type) + doc_type_string += "{},".format(b_url["doc_type"]) + doc_type_string = doc_type_string[:-1] + new_request.doc_type = doc_type_string + new_request.save() + + # if p_type in standard_ocr_list: + # ProcessUtil.send_to_queue2(rq_id, sub_id, b_urls, user_id, p_type) + # if p_type == ProcessType.TEMPLATE_MATCHING.value: + # ProcessUtil.send_template_queue(rq_id, b_urls, '', user_id) @app.task(name='upload_file_to_s3') def upload_file_to_s3(local_file_path, s3_key): diff --git a/cope2n-api/fwd_api/celery_worker/process_result_tasks.py b/cope2n-api/fwd_api/celery_worker/process_result_tasks.py index 4cc6a02..6e8c417 100755 --- a/cope2n-api/fwd_api/celery_worker/process_result_tasks.py +++ b/cope2n-api/fwd_api/celery_worker/process_result_tasks.py @@ -4,8 +4,30 @@ import uuid from fwd_api.celery_worker.worker import app from fwd_api.models import SubscriptionRequest from django.utils.crypto import get_random_string +from fwd_api.exception.exceptions import InvalidException +def aggregate_result(src_result, des_result, doc_type): + if src_result["status"] != 200: + return src_result + if not des_result: + return src_result + des_result["content"]["total_pages"] += 1 + des_result["content"]["ocr_num_pages"] += 1 + des_result["content"]["document"][0]["end_page"] += 1 + if doc_type == "imei": + des_result["content"]["document"][0]["content"][3]["value"] += src_result["content"]["document"][0]["content"][3]["value"] + elif doc_type == "invoice": + des_result["content"]["document"][0]["content"][0]["value"] = src_result["content"]["document"][0]["content"][0]["value"] + des_result["content"]["document"][0]["content"][1]["value"] = src_result["content"]["document"][0]["content"][1]["value"] + des_result["content"]["document"][0]["content"][2]["value"] += src_result["content"]["document"][0]["content"][2]["value"] + elif doc_type == "all": + des_result.update(src_result) + else: + raise InvalidException(f"doc_type: {doc_type}") + + return des_result + def print_id(rq_id): print(" [x] Received {rq}".format(rq=rq_id)) @@ -108,12 +130,28 @@ def process_invoice_sbt_result(rq_id, result): print_id(f"[DEBUG]: Received SBT request with id {rq_id}") print_id(f"[DEBUG]: result: {result}") try: + page_index = int(rq_id.split("_sub_")[1]) + rq_id = rq_id.split("_sub_")[0] rq: SubscriptionRequest = \ SubscriptionRequest.objects.filter(request_id=rq_id, process_type=ProcessType.SBT_INVOICE.value)[0] - status = to_status(result) + # status = to_status(result) + status = result.get("status", 200) + + rq.pages_left = rq.pages_left - 1 + done = rq.pages_left <= 0 + # aggregate results from multiple pages + rq.predict_result = aggregate_result(result, rq.predict_result, rq.doc_type.split(",")[page_index]) + + print_id(f"[DEBUG]: status: {status}") + + if status == 200: + if not done: + rq.status = 100 # continue + else: + rq.status = 200 # stop waiting + else: + rq.status = 404 # stop waiting - rq.predict_result = result - rq.status = status rq.save() update_user(rq) except IndexError as e: diff --git a/cope2n-api/fwd_api/models/SubscriptionRequest.py b/cope2n-api/fwd_api/models/SubscriptionRequest.py index 96505b1..47b2728 100755 --- a/cope2n-api/fwd_api/models/SubscriptionRequest.py +++ b/cope2n-api/fwd_api/models/SubscriptionRequest.py @@ -8,11 +8,13 @@ from fwd_api.models.Subscription import Subscription class SubscriptionRequest(models.Model): id = models.AutoField(primary_key=True) pages: int = models.IntegerField() + pages_left: int = models.IntegerField(default=1) doc_type: str = models.CharField(max_length=100) request_id = models.CharField(max_length=200) # Change to request_id process_type = models.CharField(max_length=200) # driver/id/invoice provider_code = models.CharField(max_length=200, default="Guest") # Request source FWD/CTel predict_result = models.JSONField(null=True) + feedback_result = models.JSONField(null=True) status = models.IntegerField() # 1: Processing(Pending) 2: PredictCompleted 3: ReturnCompleted subscription = models.ForeignKey(Subscription, on_delete=models.CASCADE) created_at = models.DateTimeField(default=timezone.now) diff --git a/cope2n-api/fwd_api/utils/FileUtils.py b/cope2n-api/fwd_api/utils/FileUtils.py index 59133a1..3e2a443 100755 --- a/cope2n-api/fwd_api/utils/FileUtils.py +++ b/cope2n-api/fwd_api/utils/FileUtils.py @@ -1,7 +1,8 @@ import io import os import traceback -import base64 +import base64 +import json from PIL import Image, ExifTags from django.core.files.uploadedfile import TemporaryUploadedFile @@ -77,7 +78,6 @@ def save_byte_file(file_name: str, rq: SubscriptionRequest, file_bytes): return file_path - def save_file(file_name: str, rq: SubscriptionRequest, file: TemporaryUploadedFile): folder_path = get_folder_path(rq) is_exist = os.path.exists(folder_path) @@ -93,6 +93,16 @@ def save_file(file_name: str, rq: SubscriptionRequest, file: TemporaryUploadedFi return file_path +def save_json_file(file_name: str, rq: SubscriptionRequest, data: dict): + folder_path = get_folder_path(rq) + is_exist = os.path.exists(folder_path) + if not is_exist: + # Create a new directory because it does not exist + os.makedirs(folder_path) + file_path = os.path.join(folder_path, file_name) + with open(file_path, "w") as json_file: + json.dump(data, json_file) + return file_path def delete_file_with_path(file_path: str) -> bool: try: diff --git a/cope2n-api/fwd_api/utils/ProcessUtil.py b/cope2n-api/fwd_api/utils/ProcessUtil.py index b6a6f0f..dd502a8 100755 --- a/cope2n-api/fwd_api/utils/ProcessUtil.py +++ b/cope2n-api/fwd_api/utils/ProcessUtil.py @@ -141,6 +141,35 @@ def sbt_validate_ocr_request_and_get(request, subscription): return validated_data +def sbt_validate_feedback(request): + validated_data = {} + + request_id = request.data.get('request_id', None) + retailername = request.data.get("retailername", None) + sold_to_party = request.data.get("sold_to_party", None) + purchase_date = request.data.getlist("purchase_date", []) + imei_number = request.data.getlist("imei_number", []) + + if not request_id: + raise RequiredFieldException(excArgs="request_id") + if not retailername: + raise RequiredFieldException(excArgs="retailername") + if not sold_to_party: + raise RequiredFieldException(excArgs="sold_to_party") + if len(purchase_date)==0: + raise RequiredFieldException(excArgs="purchase_date") + if len(imei_number)==0: + raise RequiredFieldException(excArgs="imei_number") + + + validated_data['request_id'] = request_id + validated_data['retailername'] = retailername + validated_data['sold_to_party'] = sold_to_party + validated_data['purchase_date'] = purchase_date + validated_data['imei_number'] = imei_number + + return validated_data + def count_pages_in_pdf(pdf_file): count = 0 fh, temp_filename = tempfile.mkstemp() # make a tmp file