import cv2
from common.ocr import ocr_predict
from common.crop_location import crop_location
from Kie_AHung.prediction import infer_driving_license
from Kie_AHung_ID.prediction import infer_id_card
from common.post_processing_datetime import DatetimeCorrector
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
)

MAX_SEQ_LENGTH = 512
# TODO Fix this hard code
tokenizer = LayoutXLMTokenizer.from_pretrained(
    "./Kie_AHung/model/pretrained/layoutxlm-base/tokenizer",
    model_max_length=MAX_SEQ_LENGTH,
)
# OCR is done by ocr_predict below, so the feature extractor's own OCR is off.
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
processor = LayoutXLMProcessor(feature_extractor, tokenizer)

max_n_words = 100


def _run_kie(infer_name, image, list_line):
    """Run the KIE model selected by ``infer_name`` and return its label dict.

    Args:
        infer_name: ``"driving_license"`` or ``"id_card"``.
        image: cropped document image, as returned by ``crop_location``.
        list_line: OCR output of ``ocr_predict`` for ``image``.

    Returns:
        The result of the matching ``create_kie_dict`` post-processor.

    Raises:
        ValueError: if ``infer_name`` is not one of the supported values.
    """
    # The post-processors are imported lazily, as in the original code, so
    # only the one that is actually needed gets loaded.
    if infer_name == "driving_license":
        from common.post_processing_driver import create_kie_dict

        _, _, _, list_words = infer_driving_license(
            image, list_line, max_n_words, processor
        )
    elif infer_name == "id_card":
        from common.post_processing_id import create_kie_dict

        _, _, _, list_words = infer_id_card(
            image, list_line, max_n_words, processor
        )
    else:
        # Previously this fell through and crashed later with an unbound
        # ``result_dict``; fail with an explicit, actionable error instead.
        raise ValueError(
            "Unsupported infer_name {!r}; expected 'driving_license' or "
            "'id_card'".format(infer_name)
        )
    return create_kie_dict(list_words)


def _build_fields(result_dict):
    """Convert a KIE label dict into the list of non-empty output fields."""
    fields = []
    for kie_label in result_dict:
        entry = result_dict[kie_label]
        text = entry["text"]
        if text == "":
            # Labels with no recognised text are omitted from the output.
            continue
        if kie_label == "Date Range":
            text = DatetimeCorrector.correct(text)
        fields.append(
            {
                "label": kie_label,
                # "✪" is replaced by a plain space in the reported value.
                "value": text.replace("✪", " "),
                "box": entry["bbox"],
                "confidence": 0.99,  # TODO: add confidence
            }
        )
    return fields


def predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name):
    """Crop, OCR and extract key information from every page of a request.

    Args:
        rq_id: request identifier; used only to build the cropped-image path.
        sub_id: subscription identifier; used only to build that path.
        folder_name: folder name; used only to build that path.
        list_url: iterable of page dicts providing the keys ``'file_url'``,
            ``'page_number'`` and ``'request_file_id'``.
        user_id: user identifier; used only to build the cropped-image path.
        infer_name: which extractor to run, ``"driving_license"`` or
            ``"id_card"``. It is also echoed back as the model name.

    Returns:
        dict: ``{"model": {...}, "pages": [...]}``. Each page entry holds
        ``page_index``, ``path_image_croped``, ``request_file_id`` and
        ``fields``. The path and the fields are ``None`` for a page on which
        ``crop_location`` returned ``None``.

    Raises:
        ValueError: if a page was cropped successfully but ``infer_name`` is
            not a supported extractor name.
    """
    results = {
        "model": {
            "name": infer_name,
            "confidence": 1.0,
            "type": "finance/invoice",
            "isValid": True,
            "shape": "letter",
        }
    }

    # Neither path depends on the page, so both are built once up front.
    # NOTE(review): every page therefore writes to the same file, and a later
    # page overwrites an earlier one - confirm this is intended.
    # NOTE(review): the reported path is prefixed with "/app/media" while the
    # file is written without that prefix - presumably two views of the same
    # storage; confirm against the deployment.
    path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(
        user_id, sub_id, folder_name, rq_id
    )
    write_path = "/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(
        user_id, sub_id, folder_name, rq_id
    )

    compile_outputs = []
    for page in list_url:
        image_location = crop_location(page['file_url'])
        if image_location is None:
            # No crop available: emit a placeholder entry for this page.
            compile_outputs.append(
                {
                    'page_index': page['page_number'],
                    'path_image_croped': None,
                    'request_file_id': page['request_file_id'],
                    'fields': None,
                }
            )
            continue

        # NOTE(review): the result of cv2.imwrite is ignored, so a failed
        # write still reports the path below.
        cv2.imwrite(write_path, image_location)

        list_line = ocr_predict(image_location)
        result_dict = _run_kie(infer_name, image_location, list_line)

        compile_outputs.append(
            {
                'page_index': page['page_number'],
                'path_image_croped': path_image_croped,
                'request_file_id': page['request_file_id'],
                'fields': _build_fields(result_dict),
            }
        )

    results['pages'] = compile_outputs
    return results