Add: subsidiary for getting seller

This commit is contained in:
dx-tan 2024-04-05 18:50:41 +07:00
parent 131c63454a
commit b860f1cd4f
8 changed files with 24 additions and 22 deletions

View File

@@ -27,7 +27,7 @@ sbt_engine = load_engine(kvu_cfg)
kvu_cfg["option"] = option
def sbt_predict(image_url, engine) -> None:
def sbt_predict(image_url, engine, metadata={}) -> None:
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
@@ -37,16 +37,18 @@ def sbt_predict(image_url, engine) -> None:
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
os.makedirs(save_dir, exist_ok = True)
tmp_image_path = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
cv2.imwrite(tmp_image_path, img)
cv2.imwrite(tmp_image_path, img)
extra_params = {'sub': metadata.get("subsidiary", None)} # example of 'AU'
outputs = process_img(img=tmp_image_path,
save_dir=save_dir,
engine=engine,
export_all=False, # False
option=option)
option=option,
extra_params=extra_params)
os.remove(tmp_image_path)
return outputs
def predict(page_numb, image_url):
def predict(page_numb, image_url, metadata={}):
"""
module predict function
@@ -70,7 +72,7 @@ def predict(page_numb, image_url):
dict: output of model
"""
sbt_result = sbt_predict(image_url, engine=sbt_engine)
sbt_result = sbt_predict(image_url, engine=sbt_engine, metadata=metadata)
output_dict = {
"document_type": "invoice",
"document_class": " ",

View File

@@ -62,7 +62,7 @@ def process_sbt_invoice(rq_id, list_url, metadata):
# TODO: simply returning 200 and 404 doesn't make any sense
c_connector = CeleryConnector()
try:
result = compile_output_sbt(list_url)
result = compile_output_sbt(list_url, metadata)
metadata['ai_inference_profile'] = result.pop("inference_profile")
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)

View File

@@ -200,7 +200,7 @@ def compile_output_manulife(list_url):
}
return results
def compile_output_sbt(list_url):
def compile_output_sbt(list_url, metadata):
"""_summary_
Args:
@@ -231,7 +231,7 @@ def compile_output_sbt(list_url):
start = time.time()
pages_predict_time = []
for page in list_url:
output_model = predict_sbt(page['page_number'], page['file_url'])
output_model = predict_sbt(page['page_number'], page['file_url'], metadata)
pages_predict_time.append(time.time())
if "doc_type" in page:
output_model['doc_type'] = page['doc_type']

@@ -1 +1 @@
Subproject commit 46a612a003c411406988b83b3dd6299d2a458366
Subproject commit b7baf4954e592068288c606376b035b41dd9e319

View File

@@ -686,7 +686,7 @@ def _acc_will_be_ignored(key_name, _target, type):
else:
return False
def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], target: Dict[str, Union[str, List]], type: str):
def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], target: Dict[str, Union[str, List]], type: str, sub: str):
"""_summary_
NOTE: This has been changed to return accuracy = None if
Args:
@@ -708,8 +708,8 @@ def calculate_accuracy(key_name: str, inference: Dict[str, Union[str, List]], ta
for i, v in enumerate(_inference):
# TODO: target[i] is None, ""
x = post_processing_str(key_name, _inference[i], is_gt=False)
y = post_processing_str(key_name, _target[i], is_gt=True)
x = post_processing_str(key_name, _inference[i], is_gt=False, sub=sub)
y = post_processing_str(key_name, _target[i], is_gt=True, sub=sub)
score = eval_ocr_metric(
[x],
@@ -959,8 +959,8 @@ def calculate_subcription_file(subcription_request_file):
for key_name in valid_keys:
try:
att["acc"]["feedback"][key_name], att["normalized_data"]["feedback"][key_name] = calculate_accuracy(key_name, inference_result, feedback_result, "feedback")
att["acc"]["reviewed"][key_name], att["normalized_data"]["reviewed"][key_name] = calculate_accuracy(key_name, inference_result, reviewed_result, "reviewed")
att["acc"]["feedback"][key_name], att["normalized_data"]["feedback"][key_name] = calculate_accuracy(key_name, inference_result, feedback_result, "feedback", sub=subcription_request_file.request.subsidiary)
att["acc"]["reviewed"][key_name], att["normalized_data"]["reviewed"][key_name] = calculate_accuracy(key_name, inference_result, reviewed_result, "reviewed", sub=subcription_request_file.request.subsidiary)
except Exception as e:
att["err"].append(str(e))
subcription_request_file.feedback_accuracy = att["acc"]["feedback"]

View File

@@ -40,27 +40,27 @@ def convert_datetime_format(date_string: str, is_gt=False) -> str:
return date_string
def normalise_retailer_name(retailer: str):
def normalise_retailer_name(retailer: str, sub: str):
input_value = {
"text": retailer,
"id": 0,
"class": "seller",
"bbox": [0, 0, 0, 0],
}
output = get_seller({'seller': [input_value]})
output = get_seller({'seller': [input_value]}, sub)
norm_seller_name = post_process_seller(output)
norm_seller_name = post_process_seller(output, sub)
return norm_seller_name
def post_processing_str(class_name: str, s: str, is_gt: bool) -> str:
def post_processing_str(class_name: str, s: str, is_gt: bool, sub: str) -> str:
s = str(s).replace('', ' ').strip()
if s.lower() in ['null', 'nan', "none"]:
return ''
if class_name == "purchase_date" and is_gt == True:
s = convert_datetime_format(s)
if class_name == "retailername":
s = normalise_retailer_name(s)
s = normalise_retailer_name(s, sub)
return s

@@ -1 +1 @@
Subproject commit 20128037dbfca217fa3d3ca4551cc7f8ae8a190e
Subproject commit b7baf4954e592068288c606376b035b41dd9e319

View File

@@ -179,8 +179,8 @@ services:
- ./cope2n-api:/app
working_dir: /app
command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO -c 5"
# command: bash -c "tail -f > /dev/null"
# command: sh -c "celery -A fwd_api.celery_worker.worker worker -l INFO -c 5"
command: bash -c "tail -f > /dev/null"
# Back-end persistent
db-sbt: