Add: subsidiary for getting seller

This commit is contained in:
dx-tan 2024-04-05 18:50:41 +07:00
parent 131c63454a
commit b860f1cd4f
8 changed files with 24 additions and 22 deletions

View File

@@ -27,7 +27,7 @@ sbt_engine = load_engine(kvu_cfg)
kvu_cfg["option"] = option kvu_cfg["option"] = option
def sbt_predict(image_url, engine) -> None: def sbt_predict(image_url, engine, metadata={}) -> None:
req = urllib.request.urlopen(image_url) req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8) arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1) img = cv2.imdecode(arr, -1)
@@ -37,16 +37,18 @@ def sbt_predict(image_url, engine) -> None:
# image_path = os.path.join(save_dir, f"{image_url}.jpg") # image_path = os.path.join(save_dir, f"{image_url}.jpg")
os.makedirs(save_dir, exist_ok = True) os.makedirs(save_dir, exist_ok = True)
tmp_image_path = os.path.join(save_dir, f"{uuid.uuid4()}.jpg") tmp_image_path = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
cv2.imwrite(tmp_image_path, img) cv2.imwrite(tmp_image_path, img)
extra_params = {'sub': metadata.get("subsidiary", None)} # example of 'AU'
outputs = process_img(img=tmp_image_path, outputs = process_img(img=tmp_image_path,
save_dir=save_dir, save_dir=save_dir,
engine=engine, engine=engine,
export_all=False, # False export_all=False, # False
option=option) option=option,
extra_params=extra_params)
os.remove(tmp_image_path) os.remove(tmp_image_path)
return outputs return outputs
def predict(page_numb, image_url): def predict(page_numb, image_url, metadata={}):
""" """
module predict function module predict function
@@ -70,7 +72,7 @@ def predict(page_numb, image_url):
dict: output of model dict: output of model
""" """
sbt_result = sbt_predict(image_url, engine=sbt_engine) sbt_result = sbt_predict(image_url, engine=sbt_engine, metadata=metadata)
output_dict = { output_dict = {
"document_type": "invoice", "document_type": "invoice",
"document_class": " ", "document_class": " ",

View File

@@ -62,7 +62,7 @@ def process_sbt_invoice(rq_id, list_url, metadata):
# TODO: simply returning 200 and 404 doesn't make any sense # TODO: simply returning 200 and 404 doesn't make any sense
c_connector = CeleryConnector() c_connector = CeleryConnector()
try: try:
result = compile_output_sbt(list_url) result = compile_output_sbt(list_url, metadata)
metadata['ai_inference_profile'] = result.pop("inference_profile") metadata['ai_inference_profile'] = result.pop("inference_profile")
hoadon = {"status": 200, "content": result, "message": "Success"} hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon) print(hoadon)

View File

@@ -200,7 +200,7 @@ def compile_output_manulife(list_url):
} }
return results return results
def compile_output_sbt(list_url): def compile_output_sbt(list_url, metadata):
"""_summary_ """_summary_
Args: Args:
@@ -231,7 +231,7 @@ def compile_output_sbt(list_url):
start = time.time() start = time.time()
pages_predict_time = [] pages_predict_time = []
for page in list_url: for page in list_url:
output_model = predict_sbt(page['page_number'], page['file_url']) output_model = predict_sbt(page['page_number'], page['file_url'], metadata)
pages_predict_time.append(time.time()) pages_predict_time.append(time.time())
if "doc_type" in page: if "doc_type" in page:
output_model['doc_type'] = page['doc_type'] output_model['doc_type'] = page['doc_type']

@@ -1 +1 @@
Subproject commit 46a612a003c411406988b83b3dd6299d2a458366 Subproject commit b7baf4954e592068288c606376b035b41dd9e319

View File

@@ -686,7 +686,7 @@ def _acc_will_be_ignored(key_name, _target, type):
else: else:
return False return False
def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], target: Dict[str, Union[str, List]], type: str): def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], target: Dict[str, Union[str, List]], type: str, sub: str):
"""_summary_ """_summary_
NOTE: This has been changed to return accuracy = None if NOTE: This has been changed to return accuracy = None if
Args: Args:
@@ -708,8 +708,8 @@ def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], ta
for i, v in enumerate(_inference): for i, v in enumerate(_inference):
# TODO: target[i] is None, "" # TODO: target[i] is None, ""
x = post_processing_str(key_name, _inference[i], is_gt=False) x = post_processing_str(key_name, _inference[i], is_gt=False, sub=sub)
y = post_processing_str(key_name, _target[i], is_gt=True) y = post_processing_str(key_name, _target[i], is_gt=True, sub=sub)
score = eval_ocr_metric( score = eval_ocr_metric(
[x], [x],
@@ -959,8 +959,8 @@ def calculate_subcription_file(subcription_request_file):
for key_name in valid_keys: for key_name in valid_keys:
try: try:
att["acc"]["feedback"][key_name], att["normalized_data"]["feedback"][key_name] = calculate_accuracy(key_name, inference_result, feedback_result, "feedback") att["acc"]["feedback"][key_name], att["normalized_data"]["feedback"][key_name] = calculate_accuracy(key_name, inference_result, feedback_result, "feedback", sub=subcription_request_file.request.subsidiary)
att["acc"]["reviewed"][key_name], att["normalized_data"]["reviewed"][key_name] = calculate_accuracy(key_name, inference_result, reviewed_result, "reviewed") att["acc"]["reviewed"][key_name], att["normalized_data"]["reviewed"][key_name] = calculate_accuracy(key_name, inference_result, reviewed_result, "reviewed", sub=subcription_request_file.request.subsidiary)
except Exception as e: except Exception as e:
att["err"].append(str(e)) att["err"].append(str(e))
subcription_request_file.feedback_accuracy = att["acc"]["feedback"] subcription_request_file.feedback_accuracy = att["acc"]["feedback"]

View File

@@ -40,27 +40,27 @@ def convert_datetime_format(date_string: str, is_gt=False) -> str:
return date_string return date_string
def normalise_retailer_name(retailer: str): def normalise_retailer_name(retailer: str, sub: str):
input_value = { input_value = {
"text": retailer, "text": retailer,
"id": 0, "id": 0,
"class": "seller", "class": "seller",
"bbox": [0, 0, 0, 0], "bbox": [0, 0, 0, 0],
} }
output = get_seller({'seller': [input_value]}) output = get_seller({'seller': [input_value]}, sub)
norm_seller_name = post_process_seller(output) norm_seller_name = post_process_seller(output, sub)
return norm_seller_name return norm_seller_name
def post_processing_str(class_name: str, s: str, is_gt: bool) -> str: def post_processing_str(class_name: str, s: str, is_gt: bool, sub: str) -> str:
s = str(s).replace('', ' ').strip() s = str(s).replace('', ' ').strip()
if s.lower() in ['null', 'nan', "none"]: if s.lower() in ['null', 'nan', "none"]:
return '' return ''
if class_name == "purchase_date" and is_gt == True: if class_name == "purchase_date" and is_gt == True:
s = convert_datetime_format(s) s = convert_datetime_format(s)
if class_name == "retailername": if class_name == "retailername":
s = normalise_retailer_name(s) s = normalise_retailer_name(s, sub)
return s return s

@@ -1 +1 @@
Subproject commit 20128037dbfca217fa3d3ca4551cc7f8ae8a190e Subproject commit b7baf4954e592068288c606376b035b41dd9e319

View File

@@ -179,8 +179,8 @@ services:
- ./cope2n-api:/app - ./cope2n-api:/app
working_dir: /app working_dir: /app
command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO -c 5" # command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO -c 5"
# command: bash -c "tail -f > /dev/null" command: bash -c "tail -f > /dev/null"
# Back-end persistent # Back-end persistent
db-sbt: db-sbt: