Fix result accumulation

Viet Anh Nguyen 2023-12-13 18:43:10 +07:00
parent bace56baf7
commit 6840f16ee7
15 changed files with 118 additions and 51 deletions

View File

@@ -22,7 +22,7 @@ class CeleryConnector:
app = Celery(
"postman",
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
- # backend="rpc://",
+ broker_transport_options={'confirm_publish': True},
)
def process_id_result(self, args):
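The same one-line change lands in every Celery app in this commit: with the py-amqp transport, broker_transport_options={'confirm_publish': True} enables RabbitMQ publisher confirms, so publishing a task blocks until the broker acknowledges the message instead of returning fire-and-forget. A minimal sketch of the option in isolation (the broker URL is this repo's default; the task name is illustrative only):

# Minimal sketch: a Celery app whose task publishes wait for RabbitMQ
# publisher confirms before returning.
from celery import Celery

app = Celery(
    "postman",
    broker="amqp://test:test@rabbitmq:5672",
    # Block in send_task()/apply_async() until the broker confirms it
    # has taken the message, instead of publishing fire-and-forget.
    broker_transport_options={"confirm_publish": True},
)

# app.send_task("process_id", args=[...]) now surfaces a connection
# error if the publish is not confirmed, rather than silently dropping
# the task.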

View File

@@ -19,6 +19,7 @@ class CeleryConnector:
app = Celery(
"postman",
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
+ broker_transport_options={'confirm_publish': True},
)
# mock task for FI

View File

@@ -12,6 +12,7 @@ app: Celery = Celery(
include=[
"celery_worker.mock_process_tasks",
],
+ broker_transport_options={'confirm_publish': True},
)
task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False

View File

@@ -11,6 +11,7 @@ app: Celery = Celery(
include=[
"celery_worker.mock_process_tasks_fi",
],
+ broker_transport_options={'confirm_publish': True},
)
task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False

View File

@@ -11,3 +11,4 @@ easydict
imagesize==1.4.1
pdf2image==1.16.3
+ redis==5.0.1

View File

@@ -44,6 +44,9 @@ S3_ENDPOINT = env.str("S3_ENDPOINT", "")
S3_ACCESS_KEY = env.str("S3_ACCESS_KEY", "TannedCung")
S3_SECRET_KEY = env.str("S3_SECRET_KEY", "TannedCung")
S3_BUCKET_NAME = env.str("S3_BUCKET_NAME", "ocr-data")
+ REDIS_HOST = env.str("REDIS_HOST", "result-cache")
+ REDIS_PORT = env.int("REDIS_PORT", 6379)
INSTALLED_APPS = [
"django.contrib.auth",

View File

@@ -171,9 +171,7 @@ class CtelViewSet(viewsets.ViewSet):
while True:
current_time = time.time()
waiting_time = current_time - start_time
- print("Waiting for: ", waiting_time)
if waiting_time > time_limit:
- print("Timeout!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
break
time.sleep(0.2)
report_filter = SubscriptionRequest.objects.filter(request_id=rq_id)
@@ -196,19 +194,12 @@
if report_filter[0].status == 400:
raise FileContentInvalidException()
if report_filter[0].status == 100: # continue, only return when result is fulfilled
- print(serializer.data)
- print("Status Code: 100")
continue
if len(serializer.data) == 0:
- print("No data found")
continue
if serializer.data[0].get("data", None) is None:
- print(serializer.data[0])
- print("No data[0] found")
continue
if serializer.data[0]["data"].get("status", 200) != 200:
- print("No data status found")
continue
return Response(status=status.HTTP_200_OK, data=serializer.data[0])
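For context, the view above is a simple poll-with-timeout: it re-reads the SubscriptionRequest row every 0.2 s until the worker marks it done or time_limit elapses. The shape of that loop, reduced to its essentials (names here are illustrative, not the repo's exact helpers):

import time

def wait_for_result(fetch, time_limit=30.0, poll_interval=0.2):
    """Poll fetch() until it returns a finished result or time runs out."""
    start_time = time.time()
    while time.time() - start_time < time_limit:
        result = fetch()  # e.g. re-query the request row by request_id
        if result is not None:
            return result
        time.sleep(poll_interval)
    return None  # the caller maps this to a timeout response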

View File

@@ -37,6 +37,7 @@ class CeleryConnector:
app = Celery(
'postman',
broker=settings.BROKER_URL,
+ broker_transport_options={'confirm_publish': True},
)
def do_pdf(self, args):
return self.send_task('do_pdf', args)

View File

@@ -1,28 +1,42 @@
import traceback
+ from copy import deepcopy
from fwd_api.celery_worker.worker import app
- from fwd_api.models import SubscriptionRequest
from fwd_api.exception.exceptions import InvalidException
+ from fwd_api.models import SubscriptionRequest
from fwd_api.constant.common import ProcessType
+ from fwd_api.utils.RedisUtils import RedisUtils
+ redis_client = RedisUtils()
+ def aggregate_result(results, doc_types):
+ doc_types = doc_types.split(',')
+ des_result = deepcopy(list(results.values()))[0]
+ des_result["content"]["total_pages"] = 0
+ des_result["content"]["ocr_num_pages"] = 0
+ des_result["content"]["document"][0]["end_page"] = 0
+ des_result["content"]["document"][0]["content"][3]["value"] = [None for _ in range(doc_types.count("imei"))]
+ des_result["content"]["document"][0]["content"][2]["value"] = []
+ print(f"[INFO]: results: {results}")
+ for index, result in results.items():
+ index = int(index)
+ doc_type = doc_types[index]
- def aggregate_result(src_result, des_result, doc_type):
- if src_result["status"] != 200:
- return src_result
- if not des_result:
- return src_result
des_result["content"]["total_pages"] += 1
des_result["content"]["ocr_num_pages"] += 1
des_result["content"]["document"][0]["end_page"] += 1
if doc_type == "imei":
- des_result["content"]["document"][0]["content"][3]["value"] += src_result["content"]["document"][0]["content"][3]["value"]
+ des_result["content"]["document"][0]["content"][3]["value"][index] = result["content"]["document"][0]["content"][3]["value"][0]
elif doc_type == "invoice":
- des_result["content"]["document"][0]["content"][0]["value"] = src_result["content"]["document"][0]["content"][0]["value"]
- des_result["content"]["document"][0]["content"][1]["value"] = src_result["content"]["document"][0]["content"][1]["value"]
- des_result["content"]["document"][0]["content"][2]["value"] += src_result["content"]["document"][0]["content"][2]["value"]
+ des_result["content"]["document"][0]["content"][0]["value"] = result["content"]["document"][0]["content"][0]["value"]
+ des_result["content"]["document"][0]["content"][1]["value"] = result["content"]["document"][0]["content"][1]["value"]
+ des_result["content"]["document"][0]["content"][2]["value"] += result["content"]["document"][0]["content"][2]["value"]
elif doc_type == "all":
- des_result.update(src_result)
+ des_result.update(result)
else:
raise InvalidException(f"doc_type: {doc_type}")
@@ -114,7 +128,6 @@ def process_invoice_manulife_result(rq_id, result):
@app.task(name='process_sbt_invoice_result')
def process_invoice_sbt_result(rq_id, result):
print_id(f"[DEBUG]: Received SBT request with id {rq_id}")
- print_id(f"[DEBUG]: result: {result}")
try:
page_index = int(rq_id.split("_sub_")[1])
rq_id = rq_id.split("_sub_")[0]
@@ -122,23 +135,23 @@
SubscriptionRequest.objects.filter(request_id=rq_id, process_type=ProcessType.SBT_INVOICE.value)[0]
# status = to_status(result)
status = result.get("status", 200)
- rq.pages_left = rq.pages_left - 1
- done = rq.pages_left <= 0
- # aggregate results from multiple pages
- rq.predict_result = aggregate_result(result, rq.predict_result, rq.doc_type.split(",")[page_index])
print_id(f"[DEBUG]: status: {status}")
+ redis_client.set_cache(rq_id, page_index, result)
+ done = rq.pages == redis_client.get_size(rq_id)
if status == 200:
- if not done:
rq.status = 100 # continue
- else:
+ if done:
rq.status = 200 # stop waiting
+ results = redis_client.get_all_cache(rq_id)
+ rq.predict_result = aggregate_result(results, rq.doc_type)
+ print(f"[DEBUG]: rq.predict_result: {rq.predict_result}")
+ redis_client.remove_cache(rq_id)
rq.save()
else:
rq.status = 404 # stop waiting
rq.predict_result = result
+ redis_client.remove_cache(rq_id)
rq.save()
update_user(rq)
except IndexError as e:
print(e)
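This is the core of the fix: previously each page callback decremented pages_left and merged into rq.predict_result directly, so concurrent workers could race and drop pages. Now every page result is written to its own field of a Redis hash keyed by request id, and aggregation runs exactly once, when the hash holds one entry per page. A stripped-down sketch of that fan-in pattern (hypothetical names, plain redis-py instead of the RedisUtils wrapper):

import json
import redis

r = redis.Redis(decode_responses=True)

def on_page_result(request_id, page_index, result, total_pages):
    # Each page writes its own hash field, so concurrent workers
    # cannot overwrite each other's results.
    r.hset(request_id, page_index, json.dumps(result))
    if r.hlen(request_id) < total_pages:
        return None                      # still waiting for more pages
    pages = {int(k): json.loads(v) for k, v in r.hgetall(request_id).items()}
    r.delete(request_id)                 # free the cache once consumed
    return [pages[i] for i in sorted(pages)]  # pages in original order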

View File

@@ -13,6 +13,7 @@ app: Celery = Celery(
'postman',
broker=settings.BROKER_URL,
include=['fwd_api.celery_worker.process_result_tasks', 'fwd_api.celery_worker.internal_task'],
+ broker_transport_options={'confirm_publish': True},
)
app.conf.update({

View File

@@ -1,5 +1,5 @@
from rest_framework import status
- from rest_framework.exceptions import APIException, ValidationError
+ from rest_framework.exceptions import APIException
from fwd import settings

View File

@@ -0,0 +1,42 @@
import redis
import json

from django.conf import settings


class RedisUtils:
    def __init__(self, host=settings.REDIS_HOST, port=settings.REDIS_PORT):
        self.redis_client = redis.Redis(host=host, port=port, decode_responses=True)

    def set_cache(self, request_id, image_index, data):
        """
        request_id: str
        image_index: int
        data: dict
        """
        self.redis_client.hset(request_id, image_index, json.dumps(data))

    def get_all_cache(self, request_id):
        result = {}
        for key, value in self.redis_client.hgetall(request_id).items():
            result[key] = json.loads(value)
        return result

    def get_size(self, request_id):
        return self.redis_client.hlen(request_id)

    def remove_cache(self, request_id):
        self.redis_client.delete(request_id)


if __name__ == '__main__':
    _host = "127.0.0.1"
    _port = 6379
    Yujii_A = RedisUtils(_host, _port)
    Yujii_A.set_cache("SAP123", 1, {"status": 1})
    Yujii_A.set_cache("SAP123", 2, {"status": 2})
    Yujii_A.set_cache("SAP123", 3, {"status": 3})
    print("[INFO]: data for request_id: {}".format(Yujii_A.get_all_cache("SAP123")))
    print("[INFO]: len for request_id: {}".format(Yujii_A.get_size("SAP123")))
    Yujii_A.remove_cache("SAP123")
    print("[INFO]: data for request_id: {}".format(Yujii_A.get_all_cache("SAP123")))
    print("[INFO]: len for request_id: {}".format(Yujii_A.get_size("SAP123")))

View File

@@ -49,3 +49,4 @@ djangorestframework-xml==2.0.0
boto3==1.29.7
imagesize==1.4.1
pdf2image==1.16.3
+ redis==5.0.1

View File

@@ -102,6 +102,15 @@ services:
- ctel-sbt
command: server --address :9884 --console-address :9885 /data
+ result-cache:
+   image: redis:6.2-alpine
+   restart: always
+   command: redis-server --save 20 1 --loglevel warning
+   volumes:
+     - ./data/redis:/data
+   networks:
+     - ctel-sbt
be-celery-sbt:
# build:
# context: cope2n-api
@@ -134,6 +143,9 @@ services:
- S3_SECRET_KEY=${S3_SECRET_KEY}
- S3_BUCKET_NAME=${S3_BUCKET_NAME}
- BASE_URL=http://be-ctel-sbt:${BASE_PORT}
+ - REDIS_HOST=result-cache
+ - REDIS_PORT=6379
networks:
- ctel-sbt
@@ -148,7 +160,7 @@ services:
- ./cope2n-api:/app
working_dir: /app
command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO --pool=solo"
command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO -c 3"
# Back-end persistent
db-sbt:
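Two operational details worth noting: redis-server --save 20 1 makes the result cache snapshot to disk whenever at least one key has changed in the last 20 seconds, so cached page results survive a container restart; and the worker command trades --pool=solo for -c 3, running three concurrent worker processes, which is exactly the concurrency the per-page Redis hash is there to make safe.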

View File

@@ -94,12 +94,12 @@ def process_file(data):
- invoice_files = [
- ('invoice_file', ('invoice.pdf', open("test_samples/20220303025923NHNE_20220222_Starhub_Order_Confirmation_by_Email.pdf", "rb").read())),
- ]
# invoice_files = [
# ('invoice_file', ('invoice.jpg', open("test_samples/sbt/invoice.jpg", "rb").read())),
# ('invoice_file', ('invoice.pdf', open("test_samples/20220303025923NHNE_20220222_Starhub_Order_Confirmation_by_Email.pdf", "rb").read())),
# ]
+ invoice_files = [
+ ('invoice_file', ('invoice.jpg', open("test_samples/sbt/invoice.jpg", "rb").read())),
+ ]
imei_files = [
('imei_files', ("test_samples/sbt/imei1.jpg", open("test_samples/sbt/invoice.jpg", "rb").read())),
('imei_files', ("test_samples/sbt/imei2.jpg", open("test_samples/sbt/imei2.jpg", "rb").read())),
@@ -108,8 +108,7 @@ imei_files = [
('imei_files', ("test_samples/sbt/imei5.jpg", open("test_samples/sbt/imei5.jpg", "rb").read())),
]
def get_imei_files():
- # num_files = random.randint(1, len(imei_files) + 1)
- num_files = 1
+ num_files = random.randint(1, len(imei_files) + 1)
print("Num imeis", num_files)
files = imei_files[:num_files]
# print("Num of imei files:", len(files))