"""Inference entry point for invoice key-information extraction (KIE).

Downloads an image, runs OCR on it, feeds the OCR output to the KIE model and
returns the extracted fields as a list of dictionaries.
"""
from common.utils_invoice.run_ocr import ocr_predict
import os
from Kie_Hoanglv.prediction2 import KIEInvoiceInfer
# NOTE(review): the configuration names used below (TRAINED_DIR, TOKENIZER_DIR,
# MAX_SEQ_LENGTH, KIE_LABELS, DEVICE, VISUALIZE_DIR, PRED_DIR) are not defined
# in this file, so they presumably come from this wildcard import. Prefer
# explicit imports once the config module's contents are confirmed.
from configs.config_invoice.layoutxlm_base_invoice import *
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2

# Built once at import time and shared by every call to predict().
model = KIEInvoiceInfer(
    weight_dir=TRAINED_DIR,
    tokenizer_dir=TOKENIZER_DIR,
    max_seq_len=MAX_SEQ_LENGTH,
    classes=KIE_LABELS,
    device=DEVICE,
    outdir_visualize=VISUALIZE_DIR,
)


def format_result(result):
    """Flatten the raw model output into a list of field dictionaries.

    Args:
        result: Indexable model output. ``result[0]`` is iterated and then
            indexed with each item it yields (i.e. it is used like a mapping
            of field key to value); ``result[1]`` is indexed by position and
            supplies the box for the field at the same iteration index.

    Returns:
        A list with one dictionary per entry of ``result[0]``, in iteration
        order, for example::

            [
                {"key": "name", "value": "Nguyen Hoang Hiep",
                 "true_box": [373, 113, 700, 420]},
                {"key": "address", "value": "Ha Noi",
                 "true_box": [10, 10, 20, 20]},
            ]

    Raises:
        IndexError: If ``result[1]`` has fewer entries than ``result[0]``.
    """
    fields = result[0]
    # result[1] is looked up per item (not hoisted) so that an empty
    # result[0] never touches result[1], exactly as before.
    return [
        {"key": key, "value": fields[key], "true_box": result[1][index]}
        for index, key in enumerate(fields)
    ]


def predict(image_url, timeout=30):
    """Download an image and extract invoice fields from it.

    Args:
        image_url: URL of the image to process.
        timeout: Value passed as ``timeout`` to ``requests.get`` (seconds).
            Without it a stalled server would block this call indefinitely.

    Returns:
        The list produced by ``format_result`` for the model output.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() check is redundant (and racy between check and create).
    os.makedirs(PRED_DIR, exist_ok=True)
    os.makedirs(VISUALIZE_DIR, exist_ok=True)

    response = requests.get(image_url, timeout=timeout)
    # Fail here on 4xx/5xx instead of handing an error page to the decoder.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))

    # COLOR_RGB2BGR needs a colour array; greyscale/palette images decode to
    # 2-D arrays and CMYK would be misread, so normalise to RGB for OCR.
    rgb_array = np.array(image.convert("RGB"))
    cv_image = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
    bboxes, texts = ocr_predict(cv_image)

    # Replace every "✪" in the OCR text with a space before it is passed to
    # the KIE model; str.replace is a no-op when the character is absent.
    texts_replaced = [text.replace("✪", " ") for text in texts]

    # The model still receives the image exactly as decoded (unchanged from
    # the previous behaviour); only the OCR input is normalised above.
    inputs = model.prepare_kie_inputs(image, ocr_info=[bboxes, texts_replaced])
    return format_result(model(inputs))