Add: Logging to AI services

parent 26decf38c4
commit 963940089e
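Every file touched by this commit applies the same pattern: import `logging` and `logging.config`, load the shared dict config from `utils.logging.logging.LOGGER_CONFIG`, and fetch a module-level logger that replaces the old `print` calls. The contents of `LOGGER_CONFIG` are not part of this diff, so the sketch below uses a minimal hypothetical config (the formatter, console handler, and level are assumptions) purely to illustrate the pattern.

```python
# Sketch of the per-module setup repeated throughout this commit.
# LOGGER_CONFIG actually lives in utils/logging/logging.py and is not shown here;
# this stand-in dict is an assumption for illustration only.
import logging
import logging.config

LOGGER_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {"format": "%(asctime)s %(levelname)s %(name)s: %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}

# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

logger.info("Loading engine...")  # replaces print("[INFO] Loading engine...")
```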
@@ -81,7 +81,6 @@ def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, process
# vat_outputs_invoice = export_kvu_for_all(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)
vat_outputs_invoice = export_kvu_for_manulife(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)

print(vat_outputs_invoice)
return vat_outputs_invoice


@@ -105,7 +104,6 @@ def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVU
list_images = []
for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png']:
list_images += glob.glob(os.path.join(dir_path, f'*.{ext}'))
print('No. images:', len(list_images))
for img_path in tqdm(list_images):
load_groundtruth(img_path, json_dir, save_dir, predictor, processor, export_img)

@@ -134,4 +132,3 @@ if __name__ == "__main__":
image_path = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
save_dir = "/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test"
predict_image(image_path, save_dir, predictor, processor)
print('[INFO] Done')
@@ -7,6 +7,13 @@ from torch.utils.data.dataloader import DataLoader

from lightning_modules.data_modules.kvu_dataset import KVUDataset, KVUEmbeddingDataset
from lightning_modules.utils import _get_number_samples
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

class KVUDataModule(pl.LightningDataModule):
def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor):
@@ -61,7 +68,7 @@ class KVUDataModule(pl.LightningDataModule):
f"Not supported stage: {self.cfg.stage}"
)

print('No. training samples:', len(dataset))
logger.info('No. training samples:', len(dataset))

data_loader = DataLoader(
dataset,
@@ -72,7 +79,7 @@ class KVUDataModule(pl.LightningDataModule):
)

elapsed_time = time.time() - start_time
print(f"Elapsed time for loading training data: {elapsed_time}")
logger.info(f"Elapsed time for loading training data: {elapsed_time}")

return data_loader

@@ -101,7 +108,7 @@ class KVUDataModule(pl.LightningDataModule):
f"Not supported stage: {self.cfg.stage}"
)

print('No. validation samples:', len(dataset))
logger.info('No. validation samples:', len(dataset))

data_loader = DataLoader(
dataset,
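One caveat with the converted calls above: `print('No. training samples:', len(dataset))` joins its arguments, but the standard-library logger treats extra positional arguments as %-style values for the message. With no placeholder in the message string, the record fails to format at emit time and logging reports an internal error instead of the sample count. A corrected form (a suggestion, not part of this commit) would be:

```python
# Suggested fix: give the logger a placeholder for lazy %-formatting ...
logger.info('No. training samples: %d', len(dataset))
# ... or pre-format the message, as the elapsed-time calls already do:
logger.info(f'No. training samples: {len(dataset)}')
```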
@ -18,6 +18,13 @@ import json
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import Union, Tuple, List
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
current_dir = os.getcwd()
|
||||
|
||||
|
||||
@ -42,10 +49,10 @@ def get_args():
|
||||
|
||||
|
||||
def load_engine(opt) -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
logger.info("Loading engine...")
|
||||
kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {}
|
||||
engine = OcrEngine(**kw)
|
||||
print("[INFO] Engine loaded")
|
||||
logger.info("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
|
||||
@ -64,7 +71,7 @@ def get_paths_from_opt(opt) -> Tuple[Path, Path]:
|
||||
Path(save_dir), Path(base_dir))
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir()
|
||||
print("[INFO]: Creating folder ", save_dir)
|
||||
logger.info("Creating folder ", save_dir)
|
||||
return input_image, save_dir
|
||||
|
||||
|
||||
@ -105,7 +112,7 @@ def process_dir(
|
||||
img_path.stem + ".txt"))
|
||||
process_img(img, save_path, engine, export_img)
|
||||
except Exception as e:
|
||||
print('[ERROR]: ', e, ' at ', simg_path)
|
||||
logger.error(e, ' at ', simg_path)
|
||||
continue
|
||||
ddata["img_path"].append(simg_path)
|
||||
ddata["ocr_path"].append(save_path)
|
||||
@ -125,7 +132,6 @@ def process_csv(csv_path: str, engine: OcrEngine) -> None:
|
||||
if __name__ == "__main__":
|
||||
opt = get_args()
|
||||
engine = load_engine(opt)
|
||||
print("[INFO]: OCR engine settings:", engine.settings)
|
||||
img, save_dir = get_paths_from_opt(opt)
|
||||
|
||||
lskip_dir = []
|
||||
@ -137,7 +143,6 @@ if __name__ == "__main__":
|
||||
elif img.suffix in ImageReader.supported_ext:
|
||||
process_img(str(img), save_dir, engine, opt.export_img)
|
||||
elif img.suffix == '.csv':
|
||||
print("[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used")
|
||||
process_csv(img, engine)
|
||||
else:
|
||||
raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img))
|
||||
|
@ -3,7 +3,13 @@ from typing import Optional, List
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from .utils import visualize_bbox_and_label
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Box:
|
||||
def __init__(self, x1, y1, x2, y2, conf=-1., label=""):
|
||||
@ -189,7 +195,7 @@ class Word_group:
|
||||
if word.text != "✪":
|
||||
for w in self.list_words:
|
||||
if word.word_id == w.word_id:
|
||||
print("Word id collision")
|
||||
logger.info("Word id collision")
|
||||
return False
|
||||
word.word_group_id = self.word_group_id #
|
||||
word.line_id = self.line_id
|
||||
@ -260,7 +266,7 @@ class Line:
|
||||
if word_group.list_words is not None:
|
||||
for wg in self.list_word_groups:
|
||||
if word_group.word_group_id == wg.word_group_id:
|
||||
print("Word_group id collision")
|
||||
logger.info("Word_group id collision")
|
||||
return False
|
||||
|
||||
self.list_word_groups.append(word_group)
|
||||
@ -352,7 +358,7 @@ class Paragraph:
|
||||
if line.list_word_groups is not None:
|
||||
for l in self.list_lines:
|
||||
if line.line_id == l.line_id:
|
||||
print("Line id collision")
|
||||
logger.info("Line id collision")
|
||||
return False
|
||||
for i in range(len(line.list_word_groups)):
|
||||
line.list_word_groups[
|
||||
|
@@ -16,7 +16,13 @@ from .dto import Word, Line, Page, Document, Box
# from .word_formation import words_to_lines_mmocr as words_to_lines
from .word_formation import words_to_lines_tesseract as words_to_lines
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

class OcrEngine:
def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict):
@@ -35,7 +41,7 @@ class OcrEngine:

if "cuda" in self.__settings["device"]:
if not torch.cuda.is_available():
print("[WARNING]: CUDA is not available, running with cpu instead")
logger.warning("CUDA is not available, running with cpu instead")
self.__settings["device"] = "cpu"
self._detector = StandaloneYOLOXRunner(
version=self.__settings["detector"],
@ -12,7 +12,13 @@ from pdf2image import convert_from_path
|
||||
from deskew import determine_skew
|
||||
from jdeskew.estimator import get_angle
|
||||
from jdeskew.utility import rotate as jrotate
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def post_process_recog(text: str) -> str:
|
||||
text = text.replace("✪", " ")
|
||||
@ -30,7 +36,7 @@ class Timer:
|
||||
def __exit__(self, func: Callable, *args):
|
||||
self.end_time = time.perf_counter()
|
||||
self.elapsed_time = self.end_time - self.start_time
|
||||
print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds")
|
||||
logger.info(f"{self.name} took : {self.elapsed_time:.6f} seconds")
|
||||
|
||||
|
||||
def rotate(
|
||||
@ -201,8 +207,8 @@ class ImageReader:
|
||||
ImageReader.validate_img_path(img_path)
|
||||
limgs.append(ImageReader._read(img_path))
|
||||
except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e:
|
||||
print("[ERROR]: ", e)
|
||||
print("[INFO]: Skipping image {}".format(img_path))
|
||||
logger.error(e)
|
||||
logger.error("Skipping image {}".format(img_path))
|
||||
return limgs
|
||||
|
||||
@staticmethod
|
||||
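In the error paths above the commit swaps `print("[ERROR]: ", e)` for `logger.error(e)`, which still drops the traceback. An optional refinement, not part of this diff, is `logger.exception`, which logs at ERROR level and appends the active traceback. The sketch below uses a plain `open` call as a stand-in for `ImageReader._read`:

```python
import logging

logger = logging.getLogger(__name__)

def read_images(paths):
    """Collect readable images, logging each failure together with its traceback."""
    images = []
    for img_path in paths:
        try:
            images.append(open(img_path, "rb").read())  # stand-in for ImageReader._read(img_path)
        except (FileNotFoundError, IsADirectoryError):
            # logger.exception == logger.error(..., exc_info=True): the skipped path is
            # reported and the traceback is preserved in the log.
            logger.exception("Skipping image %s", img_path)
    return images
```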
|
@ -2,6 +2,14 @@ from builtins import dict
|
||||
from .dto import Word, Line, Word_group, Box
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple, Union
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_IOU_HEIGHT = 0.7
|
||||
MIN_WIDTH_LINE_RATIO = 0.05
|
||||
|
||||
@ -485,7 +493,7 @@ def near(word_group1: Word_group, word_group2: Word_group):
|
||||
if overlap > 0:
|
||||
return True
|
||||
if abs(overlap / min_height) < 1.5:
|
||||
print("near enough", abs(overlap / min_height), overlap, min_height)
|
||||
logger.info("near enough", abs(overlap / min_height), overlap, min_height)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -9,7 +9,14 @@ sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: ???????
|
||||
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
||||
from model import get_model
|
||||
from utils import load_model_weight
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class KVUPredictor:
|
||||
def __init__(self, configs, class_names, dummy_idx, mode=0):
|
||||
@ -20,9 +27,9 @@ class KVUPredictor:
|
||||
self.dummy_idx = dummy_idx
|
||||
self.mode = mode
|
||||
|
||||
print('[INFO] Loading Key-Value Understanding model ...')
|
||||
logger.info('Loading Key-Value Understanding model ...')
|
||||
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
|
||||
print("[INFO] Loaded model")
|
||||
logger.info("Loaded model")
|
||||
|
||||
if mode == 3:
|
||||
self.max_window_count = cfg.train.max_window_count
|
||||
@ -41,7 +48,7 @@ class KVUPredictor:
|
||||
cfg.stage = self.mode
|
||||
backbone_type = cfg.model.backbone
|
||||
|
||||
print('[INFO] Checkpoint:', ckpt_path)
|
||||
logger.info('Checkpoint:', ckpt_path)
|
||||
net = get_model(cfg)
|
||||
load_model_weight(net, ckpt_path)
|
||||
net.to('cuda')
|
||||
|
@ -6,7 +6,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.plugins import DDPPlugin
|
||||
from utils.ema_callbacks import EMA
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _update_config(cfg):
|
||||
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
||||
@ -14,7 +20,7 @@ def _update_config(cfg):
|
||||
|
||||
# set per-gpu batch size
|
||||
num_devices = torch.cuda.device_count()
|
||||
print('No. devices:', num_devices)
|
||||
logger.info('No. devices:', num_devices)
|
||||
for mode in ["train", "val"]:
|
||||
new_batch_size = cfg[mode].batch_size // num_devices
|
||||
cfg[mode].batch_size = new_batch_size
|
||||
@ -89,15 +95,15 @@ def create_exp_dir(save_dir=''):
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
else:
|
||||
print("DIR already existed.")
|
||||
print('Experiment dir : {}'.format(save_dir))
|
||||
logger.info("DIR already existed.")
|
||||
logger.info('Experiment dir : {}'.format(save_dir))
|
||||
|
||||
def create_dir(save_dir=''):
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
else:
|
||||
print("DIR already existed.")
|
||||
print('Save dir : {}'.format(save_dir))
|
||||
logger.info("DIR already existed.")
|
||||
logger.info('Save dir : {}'.format(save_dir))
|
||||
|
||||
def load_checkpoint(ckpt_path, model, key_include):
|
||||
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
|
||||
@ -109,7 +115,7 @@ def load_checkpoint(ckpt_path, model, key_include):
|
||||
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
|
||||
del state_dict[key]
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
print(f"Load checkpoint at {ckpt_path}")
|
||||
logger.info(f"Load checkpoint at {ckpt_path}")
|
||||
return model
|
||||
|
||||
def load_model_weight(net, pretrained_model_file):
|
||||
|
@ -10,25 +10,31 @@ from pdf2image import convert_from_path
|
||||
from dicttoxml import dicttoxml
|
||||
from word_preprocess import vat_standardizer, get_string, ap_standardizer
|
||||
from kvu_dictionary import vat_dictionary, ap_dictionary
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_dir(save_dir=''):
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
else:
|
||||
print("DIR already existed.")
|
||||
print('Save dir : {}'.format(save_dir))
|
||||
logger.info("DIR already existed.")
|
||||
logger.info('Save dir : {}'.format(save_dir))
|
||||
|
||||
def pdf2image(pdf_dir, save_dir):
|
||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||
print('No. pdf files:', len(pdf_files))
|
||||
logger.info('No. pdf files:', len(pdf_files))
|
||||
|
||||
for file in tqdm(pdf_files):
|
||||
pages = convert_from_path(file, 500)
|
||||
for i, page in enumerate(pages):
|
||||
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||
print('Done!!!')
|
||||
logger.info('Done!!!')
|
||||
|
||||
def xyxy2xywh(bbox):
|
||||
return [
|
||||
@ -239,7 +245,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
||||
try:
|
||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||
except:
|
||||
print('Not valid pair:', wg_from, wg_to)
|
||||
logger.info('Not valid pair:', wg_from, wg_to)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -257,7 +263,7 @@ def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['othe
|
||||
triplet_pairs = []
|
||||
single_pairs = []
|
||||
table = []
|
||||
# print('key2values_relations', key2values_relations)
|
||||
# logger.info('key2values_relations', key2values_relations)
|
||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||
if len(list_value_group_ids) == 0: continue
|
||||
elif len(list_value_group_ids) == 1:
|
||||
@ -343,7 +349,7 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
||||
for pair in outputs['single']:
|
||||
for key_name, value in pair.items():
|
||||
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
||||
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs):
|
||||
single_pairs[key_name].append({
|
||||
@ -352,8 +358,8 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
||||
'lcs_score': score,
|
||||
'token_id': value['id']
|
||||
})
|
||||
# print('='*10, file_path)
|
||||
# print(vat_info)
|
||||
# logger.info('='*10, file_path)
|
||||
# logger.info(vat_info)
|
||||
# Combine VAT information and table
|
||||
vat_outputs = {k: None for k in list(single_pairs)}
|
||||
for key_name, list_potential_value in single_pairs.items():
|
||||
@ -387,7 +393,7 @@ def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['ot
|
||||
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
||||
for cell in single_item:
|
||||
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
||||
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||
if header_name in list(item.keys()):
|
||||
item[header_name].append({
|
||||
'content': cell['text'],
|
||||
@ -436,7 +442,7 @@ def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['ot
|
||||
for pair in outputs['single']:
|
||||
for key_name, value in pair.items():
|
||||
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
||||
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs):
|
||||
single_pairs[key_name].append({
|
||||
|
@ -5,12 +5,19 @@ import sys, os
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine"))
|
||||
from src.ocr import OcrEngine
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_ocr_engine() -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
logger.info("[INFO] Loading engine...")
|
||||
engine = OcrEngine()
|
||||
print("[INFO] Engine loaded")
|
||||
logger.info("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
||||
|
@ -22,6 +22,13 @@ from utils.kvu_dictionary import (
|
||||
ap_dictionary,
|
||||
manulife_dictionary
|
||||
)
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -29,20 +36,20 @@ def create_dir(save_dir=''):
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# else:
|
||||
# print("DIR already existed.")
|
||||
# print('Save dir : {}'.format(save_dir))
|
||||
# logger.info("DIR already existed.")
|
||||
# logger.info('Save dir : {}'.format(save_dir))
|
||||
|
||||
def convert_pdf2img(pdf_dir, save_dir):
|
||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||
print('No. pdf files:', len(pdf_files))
|
||||
print(pdf_files)
|
||||
logger.info('No. pdf files:', len(pdf_files))
|
||||
logger.info(pdf_files)
|
||||
|
||||
for file in tqdm(pdf_files):
|
||||
pdf2img(file, save_dir, n_pages=-1, return_fname=False)
|
||||
# pages = convert_from_path(file, 500)
|
||||
# for i, page in enumerate(pages):
|
||||
# page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||
print('Done!!!')
|
||||
logger.info('Done!!!')
|
||||
|
||||
def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False):
|
||||
file_names = []
|
||||
@ -296,7 +303,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
||||
try:
|
||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||
except:
|
||||
print('Not valid pair:', wg_from, wg_to)
|
||||
logger.info('Not valid pair:', wg_from, wg_to)
|
||||
return outputs
|
||||
|
||||
def get_single_entity(word_groups: dict, lrelations: list) -> list:
|
||||
@ -324,7 +331,7 @@ def export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labe
|
||||
triplet_pairs = []
|
||||
single_pairs = []
|
||||
table = []
|
||||
# print('key2values_relations', key2values_relations)
|
||||
# logger.info('key2values_relations', key2values_relations)
|
||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||
if len(list_value_group_ids) == 0: continue
|
||||
elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())):
|
||||
@ -443,7 +450,7 @@ def export_kvu_for_all(file_path, lwords, lbboxes, class_words, lrelations, labe
|
||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||
if header_list:
|
||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||
print("Header_list:", header_list.keys())
|
||||
logger.info("Header_list:", header_list.keys())
|
||||
|
||||
for row in raw_outputs["table"]:
|
||||
item = {header: None for header in list(header_list.keys())}
|
||||
@ -517,7 +524,7 @@ def export_kvu_for_manulife(
|
||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||
if header_list:
|
||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||
# print("Header_list:", header_list.keys())
|
||||
# logger.info("Header_list:", header_list.keys())
|
||||
|
||||
for row in raw_outputs["table"]:
|
||||
item = {header: None for header in list(header_list.keys())}
|
||||
@ -539,7 +546,7 @@ def get_vat_table_information(outputs):
|
||||
for single_item in outputs['table']:
|
||||
headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item]
|
||||
item = {k: [] for k in headers}
|
||||
print(item)
|
||||
logger.info(item)
|
||||
for cell in single_item:
|
||||
# header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
|
||||
# if header_name in list(item.keys()):
|
||||
@ -565,7 +572,7 @@ def get_vat_table_information(outputs):
|
||||
# if item["Mặt hàng"] == None:
|
||||
# continue
|
||||
table.append(item)
|
||||
print(table)
|
||||
logger.info(table)
|
||||
return table
|
||||
|
||||
def get_vat_information(outputs):
|
||||
@ -574,7 +581,7 @@ def get_vat_information(outputs):
|
||||
for pair in outputs['single']:
|
||||
for raw_key_name, value in pair.items():
|
||||
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
@ -588,7 +595,7 @@ def get_vat_information(outputs):
|
||||
for key, value_list in triplet.items():
|
||||
if len(value_list) == 1:
|
||||
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
@ -600,7 +607,7 @@ def get_vat_information(outputs):
|
||||
|
||||
for pair in value_list:
|
||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
@ -613,7 +620,7 @@ def get_vat_information(outputs):
|
||||
for table_row in outputs['table']:
|
||||
for pair in table_row:
|
||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
@ -674,7 +681,7 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
||||
vat_outputs['table'] = table
|
||||
|
||||
write_to_json(file_path, vat_outputs)
|
||||
print(vat_outputs)
|
||||
logger.info(vat_outputs)
|
||||
return vat_outputs
|
||||
|
||||
|
||||
@ -686,7 +693,7 @@ def get_ap_table_information(outputs):
|
||||
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
||||
for cell in single_item:
|
||||
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
||||
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||
if header_name in list(item.keys()):
|
||||
item[header_name].append({
|
||||
'content': cell['text'],
|
||||
@ -740,7 +747,7 @@ def get_ap_information(outputs):
|
||||
for pair in outputs['single']:
|
||||
for raw_key_name, value in pair.items():
|
||||
key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs):
|
||||
single_pairs[key_name].append({
|
||||
@ -763,7 +770,7 @@ def get_ap_information(outputs):
|
||||
|
||||
if all(v is not None for k, v in pair.items()) and is_product_info == True:
|
||||
key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False)
|
||||
# print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")
|
||||
# logger.info(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")
|
||||
|
||||
if key_name in list(single_pairs):
|
||||
single_pairs[key_name].append({
|
||||
@ -778,7 +785,7 @@ def get_ap_information(outputs):
|
||||
for key_name, list_potential_value in single_pairs.items():
|
||||
if len(list_potential_value) == 0: continue
|
||||
if key_name == "imei_number":
|
||||
# print('list_potential_value', list_potential_value)
|
||||
# logger.info('list_potential_value', list_potential_value)
|
||||
# ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5]
|
||||
ap_outputs[key_name] = []
|
||||
for v in list_potential_value:
|
||||
|
@ -1,3 +1,12 @@
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Word():
|
||||
def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""):
|
||||
self.type = "word"
|
||||
@ -43,7 +52,7 @@ class Word_group():
|
||||
if word.text != "✪":
|
||||
for w in self.list_words:
|
||||
if word.word_id == w.word_id:
|
||||
print("Word id collision")
|
||||
logger.info("Word id collision")
|
||||
return False
|
||||
word.word_group_id = self.word_group_id #
|
||||
word.line_id = self.line_id
|
||||
@ -92,7 +101,7 @@ class Line():
|
||||
if word_group.list_words is not None:
|
||||
for wg in self.list_word_groups:
|
||||
if word_group.word_group_id == wg.word_group_id:
|
||||
print("Word_group id collision")
|
||||
logger.info("Word_group id collision")
|
||||
return False
|
||||
|
||||
self.list_word_groups.append(word_group)
|
||||
@ -176,7 +185,6 @@ def words_to_lines(words, check_special_lines=True): #words is list of Word inst
|
||||
new_line.merge_word(word)
|
||||
lines.append(new_line)
|
||||
|
||||
# print(len(lines))
|
||||
#sort line from top to bottom according top coordinate
|
||||
lines.sort(key = lambda x: x.boundingbox[1])
|
||||
|
||||
@ -189,7 +197,6 @@ def words_to_lines(words, check_special_lines=True): #words is list of Word inst
|
||||
continue
|
||||
#left, top ,right, bottom
|
||||
line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
|
||||
# print("line_width",line_width)
|
||||
lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right
|
||||
|
||||
#update text for lines after sorting
|
||||
|
@ -4,6 +4,15 @@ import string
|
||||
import copy
|
||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML
|
||||
from word2line import Word, words_to_lines
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
nltk.download('words')
|
||||
words = set(nltk.corpus.words.words())
|
||||
|
||||
@ -32,7 +41,6 @@ def remove_punctuation(text):
|
||||
|
||||
def remove_accents(input_str, s0, s1):
|
||||
s = ''
|
||||
# print input_str.encode('utf-8')
|
||||
for c in input_str:
|
||||
if c in s1:
|
||||
s += s0[s1.index(c)]
|
||||
@ -44,7 +52,6 @@ def remove_spaces(text):
|
||||
return text.replace(' ', '')
|
||||
|
||||
def preprocessing(text: str):
|
||||
# text = remove_english_words(text) if table else text
|
||||
text = remove_punctuation(text)
|
||||
text = remove_accents(text, s0, s1)
|
||||
text = remove_spaces(text)
|
||||
@ -184,7 +191,7 @@ def post_process_for_item(item: dict) -> dict:
|
||||
elif mis_key[0] == check_keys[2]:
|
||||
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
|
||||
except Exception as e:
|
||||
print("Cannot post process this item with error:", e)
|
||||
logger.error("Cannot post process this item with error:", e)
|
||||
return item
|
||||
|
||||
|
||||
@ -280,9 +287,9 @@ def get_string_with_word2line(lwords: list, lbboxes: list):
|
||||
string_after_word2line = ' '.join(list_sorted_words)
|
||||
|
||||
if string_from_model != string_after_word2line:
|
||||
print("[Warning] Word group from model is different with word2line module")
|
||||
print("Model: ", ' '.join(unique_list))
|
||||
print("Word2line: ", ' '.join(list_sorted_words))
|
||||
logger.warning("[Warning] Word group from model is different with word2line module")
|
||||
logger.warning("Model: ", ' '.join(unique_list))
|
||||
logger.warning("Word2line: ", ' '.join(list_sorted_words))
|
||||
|
||||
return string_after_word2line
|
||||
|
||||
|
@ -49,10 +49,8 @@ def predict(image_url):
|
||||
"confidence": output[key]['conf']
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
print(output_dict)
|
||||
return output_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg"
|
||||
output = predict(image_url)
|
||||
print(output)
|
@ -60,18 +60,12 @@ def predict_fi(page_numb, image_url):
|
||||
output_kie = {
|
||||
field_name: field_item['value'] for field_name, field_item in output.items()
|
||||
}
|
||||
# print("Hoangggggggggggggggggggggggggggggggggggggggggggggg")
|
||||
# print(output_kie)
|
||||
|
||||
|
||||
#Phan cua Tuan
|
||||
kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor)
|
||||
# print("TuanNnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn")
|
||||
# print(kvu_result)
|
||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||
return kvu_result, output_kie
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg"
|
||||
output = predict_fi(0, image_url)
|
||||
print(output)
|
@ -69,7 +69,6 @@ def predict(page_numb, image_url):
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
print(output_dict)
|
||||
return output_dict
|
||||
|
||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||
@ -143,4 +142,3 @@ def predict(page_numb, image_url):
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
@ -1,106 +0,0 @@
|
||||
1113 773 1220 825 BEST
|
||||
1243 759 1378 808 DENKI
|
||||
1410 752 1487 799 (S)
|
||||
1430 707 1515 748 TAX
|
||||
1511 745 1598 790 PTE
|
||||
1542 700 1725 740 TNVOICE
|
||||
1618 742 1706 783 LTD
|
||||
1783 725 1920 773 FUNAN
|
||||
1943 723 2054 767 MALL
|
||||
1434 797 1576 843 WORTH
|
||||
1599 785 1760 831 BRIDGE
|
||||
1784 778 1846 822 RD
|
||||
1277 846 1632 897 #02-16/#03-1
|
||||
1655 832 1795 877 FUNAN
|
||||
1817 822 1931 869 MALL
|
||||
1272 897 1518 956 S(179105)
|
||||
1548 890 1655 943 TEL:
|
||||
1686 877 1911 928 69046183
|
||||
1247 1011 1334 1068 GST
|
||||
1358 1006 1447 1059 REG
|
||||
1360 1063 1449 1115 RCB
|
||||
1473 1003 1575 1055 NO.:
|
||||
1474 1059 1555 1110 NO.
|
||||
1595 1042 1868 1096 198202199E
|
||||
1607 985 1944 1040 M2-0053813-7
|
||||
1056 1134 1254 1194 Opening
|
||||
1276 1127 1391 1181 Hrs:
|
||||
1425 1112 1647 1170 10:00:00
|
||||
1672 1102 1735 1161 AN
|
||||
1755 1101 1819 1157 to
|
||||
1846 1090 2067 1147 10:00:00
|
||||
2090 1080 2156 1141 PH
|
||||
1061 1308 1228 1366 Staff:
|
||||
1258 1300 1378 1357 3296
|
||||
1710 1283 1880 1337 Trans:
|
||||
1936 1266 2192 1322 262152554
|
||||
1060 1372 1201 1429 Date:
|
||||
1260 1358 1494 1419 22-03-23
|
||||
1540 1344 1664 1409 9:05
|
||||
1712 1339 1856 1407 Slip:
|
||||
1917 1328 2196 1387 2000130286
|
||||
1124 1487 1439 1545 SALESPERSON
|
||||
1465 1477 1601 1537 CODE.
|
||||
1633 1471 1752 1530 6043
|
||||
1777 1462 2004 1519 HUHAHHAD
|
||||
2032 1451 2177 1509 RAZIH
|
||||
1070 1558 1187 1617 Item
|
||||
1211 1554 1276 1615 No
|
||||
1439 1542 1585 1601 Price
|
||||
1750 1530 1841 1597 Qty
|
||||
1951 1517 2120 1579 Amount
|
||||
1076 1683 1276 1741 ANDROID
|
||||
1304 1673 1477 1733 TABLET
|
||||
1080 1746 1280 1804 2105976
|
||||
1509 1729 1705 1784 SAMSUNG
|
||||
1734 1719 1931 1776 SH-P613
|
||||
1964 1709 2101 1768 128GB
|
||||
1082 1809 1285 1869 SM-P613
|
||||
1316 1802 1454 1860 12838
|
||||
1429 1859 1600 1919 518.00
|
||||
1481 1794 1596 1855 WIFI
|
||||
1622 1790 1656 1850 G
|
||||
1797 1845 1824 1904 1
|
||||
1993 1832 2165 1892 518.00
|
||||
1088 1935 1347 1995 PROMOTION
|
||||
1091 2000 1294 2062 2105664
|
||||
1520 1983 1717 2039 SAMSUNG
|
||||
1743 1963 2106 2030 F-Sam-Redeen
|
||||
1439 2111 1557 2173 0.00
|
||||
1806 2095 1832 2156 1
|
||||
2053 2081 2174 2144 0.00
|
||||
1106 2248 1250 2312 Total
|
||||
1974 2206 2146 2266 518.00
|
||||
1107 2312 1204 2377 UOB
|
||||
1448 2291 1567 2355 CARD
|
||||
1978 2268 2147 2327 518.00
|
||||
1253 2424 1375 2497 GST%
|
||||
1456 2411 1655 2475 Net.Amt
|
||||
1818 2393 1912 2460 GST
|
||||
2023 2387 2192 2445 Amount
|
||||
1106 2494 1231 2560 GST8
|
||||
1486 2472 1661 2537 479.63
|
||||
1770 2458 1916 2523 38.37
|
||||
2027 2448 2203 2511 518.00
|
||||
1553 2601 1699 2666 THANK
|
||||
1721 2592 1821 2661 YOU
|
||||
1436 2678 1616 2749 please
|
||||
1644 2682 1764 2732 come
|
||||
1790 2660 1942 2729 again
|
||||
1191 2862 1391 2931 Those
|
||||
1426 2870 2018 2945 facebook.com
|
||||
1565 2809 1690 2884 join
|
||||
1709 2816 1777 2870 us
|
||||
1799 2811 1868 2865 on
|
||||
1838 2946 2024 3003 com .89
|
||||
1533 3006 2070 3088 ar.com/askbe
|
||||
1300 3326 1659 3446 That's
|
||||
1696 3308 1905 3424 not
|
||||
1937 3289 2131 3408 all!
|
||||
1450 3511 1633 3573 SCAN
|
||||
1392 3589 1489 3645 QR
|
||||
1509 3577 1698 3635 CODE
|
||||
1321 3656 1370 3714 &
|
||||
1517 3638 1768 3699 updates
|
||||
1643 3882 1769 3932 Scan
|
||||
1789 3868 1859 3926 Me
|
Binary file not shown.
@ -100,7 +100,6 @@ def word_to_line(list_words):
|
||||
"""
|
||||
texts, boundingboxes = [], []
|
||||
for line in list_words:
|
||||
print(line.text)
|
||||
if line.text == "":
|
||||
continue
|
||||
else:
|
||||
|
@ -6,6 +6,13 @@ det_ckpt = "/models/sdsvtd/hub/wild_receipt_finetune_weights_c_lite.pth"
|
||||
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
||||
|
||||
engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ocr_predict(img):
|
||||
@ -24,7 +31,7 @@ def ocr_predict(img):
|
||||
return list_lines
|
||||
# return lbboxes, lwords
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
logger.info(e)
|
||||
list_lines = []
|
||||
return list_lines
|
||||
|
||||
|
@ -9,16 +9,23 @@ sys.path.append(cur_dir)
|
||||
from modules.sdsvkvu import load_engine, process_img
|
||||
from modules.ocr_engine import OcrEngine
|
||||
from configs.manulife import device, ocr_cfg, kvu_cfg
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_ocr_engine(opt) -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
logger.info("[INFO] Loading engine...")
|
||||
engine = OcrEngine(**opt)
|
||||
print("[INFO] Engine loaded")
|
||||
logger.info("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
|
||||
print("OCR engine configfs: \n", ocr_cfg)
|
||||
print("KVU configfs: \n", kvu_cfg)
|
||||
logger.info("OCR engine configfs: \n", ocr_cfg)
|
||||
logger.info("KVU configfs: \n", kvu_cfg)
|
||||
|
||||
ocr_engine = load_ocr_engine(ocr_cfg)
|
||||
kvu_cfg['ocr_engine'] = ocr_engine
|
||||
@ -86,7 +93,7 @@ def predict(page_numb, image_url):
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
print(output_dict)
|
||||
logger.info(output_dict)
|
||||
return output_dict
|
||||
|
||||
|
||||
@ -95,4 +102,4 @@ def predict(page_numb, image_url):
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
||||
logger.info(output)
|
@ -14,9 +14,17 @@ nltk.data.path.append(os.path.join((os.getcwd() + '/nltk_data')))
|
||||
|
||||
from modules.sdsvkvu import load_engine, process_img
|
||||
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
|
||||
print("OCR engine configfs: \n", ocr_cfg)
|
||||
print("KVU configfs: \n", kvu_cfg)
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
|
||||
# Get the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info("OCR engine configfs: \n", ocr_cfg)
|
||||
logger.info("KVU configfs: \n", kvu_cfg)
|
||||
|
||||
# ocr_engine = load_ocr_engine(ocr_cfg)
|
||||
# kvu_cfg['ocr_engine'] = ocr_engine
|
||||
@ -40,7 +48,7 @@ def sbt_predict(image_url, engine, metadata={}) -> None:
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
file_name = query_params['file_name'][0]
|
||||
except Exception as e:
|
||||
print(f"[ERROR]: Error extracting file name from url: {image_url}")
|
||||
logger.info(f"[ERROR]: Error extracting file name from url: {image_url}")
|
||||
file_name = f"{uuid.uuid4()}.jpg"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
|
||||
@ -103,4 +111,4 @@ def predict(page_numb, image_url, metadata={}):
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
||||
logger.info(output)
|
@ -1,81 +0,0 @@
|
||||
from celery import Celery
|
||||
import base64
|
||||
import environ
|
||||
env = environ.Env(
|
||||
DEBUG=(bool, False)
|
||||
)
|
||||
|
||||
class CeleryConnector:
|
||||
task_routes = {
|
||||
"process_id_result": {"queue": "id_card_rs"},
|
||||
"process_driver_license_result": {"queue": "driver_license_rs"},
|
||||
"process_invoice_result": {"queue": "invoice_rs"},
|
||||
"process_ocr_with_box_result": {"queue": "ocr_with_box_rs"},
|
||||
"process_template_matching_result": {"queue": "template_matching_rs"},
|
||||
# mock task
|
||||
"process_id": {"queue": "id_card"},
|
||||
"process_driver_license": {"queue": "driver_license"},
|
||||
"process_invoice": {"queue": "invoice"},
|
||||
"process_ocr_with_box": {"queue": "ocr_with_box"},
|
||||
"process_template_matching": {"queue": "template_matching"},
|
||||
}
|
||||
app = Celery(
|
||||
"postman",
|
||||
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
|
||||
broker_transport_options={'confirm_publish': False},
|
||||
)
|
||||
|
||||
def process_id_result(self, args):
|
||||
return self.send_task("process_id_result", args)
|
||||
|
||||
def process_driver_license_result(self, args):
|
||||
return self.send_task("process_driver_license_result", args)
|
||||
|
||||
def process_invoice_result(self, args):
|
||||
return self.send_task("process_invoice_result", args)
|
||||
|
||||
def process_ocr_with_box_result(self, args):
|
||||
return self.send_task("process_ocr_with_box_result", args)
|
||||
|
||||
def process_template_matching_result(self, args):
|
||||
return self.send_task("process_template_matching_result", args)
|
||||
|
||||
def process_id(self, args):
|
||||
return self.send_task("process_id", args)
|
||||
|
||||
def process_driver_license(self, args):
|
||||
return self.send_task("process_driver_license", args)
|
||||
|
||||
def process_invoice(self, args):
|
||||
return self.send_task("process_invoice", args)
|
||||
|
||||
def process_ocr_with_box(self, args):
|
||||
return self.send_task("process_ocr_with_box", args)
|
||||
|
||||
def process_template_matching(self, args):
|
||||
return self.send_task("process_template_matching", args)
|
||||
|
||||
def send_task(self, name=None, args=None):
|
||||
if name not in self.task_routes or "queue" not in self.task_routes[name]:
|
||||
return self.app.send_task(name, args)
|
||||
|
||||
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
|
||||
|
||||
|
||||
def main():
|
||||
rq_id = 345
|
||||
file_names = "abc.jpg"
|
||||
list_data = []
|
||||
|
||||
with open("/home/sds/thucpd/aicr-2022/abc.jpg", "rb") as fs:
|
||||
encoded_string = base64.b64encode(fs.read()).decode("utf-8")
|
||||
list_data.append(encoded_string)
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
a = c_connector.process_id(args=(rq_id, list_data, file_names))
|
||||
|
||||
print(a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,5 +1,7 @@
from celery import Celery
import environ
from utils.logging.local_storage import get_current_trace_id

env = environ.Env(
    DEBUG=(bool, False)
)
@@ -53,5 +55,6 @@ class CeleryConnector:
    def send_task(self, name=None, args=None):
        if name not in self.task_routes or "queue" not in self.task_routes[name]:
            return self.app.send_task(name, args)

        trace_id = get_current_trace_id()
        args += (trace_id,)  # add trace_id to args then remove before start
        return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
@ -1,220 +0,0 @@
|
||||
from celery_worker.worker import app
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
@app.task(name="process_id")
|
||||
def process_id(rq_id, sub_id, folder_name, list_url, user_id):
|
||||
from common.serve_model import predict
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="id_card")
|
||||
print(result)
|
||||
result = {
|
||||
"status": 200,
|
||||
"content": result,
|
||||
"message": "Success",
|
||||
}
|
||||
c_connector.process_id_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
# if image_croped is not None:
|
||||
# if result["data"] == []:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
# else:
|
||||
# result = {
|
||||
# "status": 200,
|
||||
# "content": result,
|
||||
# "message": "Success",
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result))
|
||||
# return {"rq_id": rq_id}
|
||||
# elif image_croped is None:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {
|
||||
"status": 404,
|
||||
"content": {},
|
||||
}
|
||||
c_connector.process_id_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_driver_license")
|
||||
def process_driver_license(rq_id, sub_id, folder_name, list_url, user_id):
|
||||
from common.serve_model import predict
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="driving_license")
|
||||
result = {
|
||||
"status": 200,
|
||||
"content": result,
|
||||
"message": "Success",
|
||||
}
|
||||
c_connector.process_driver_license_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
# result, image_croped = predict(str(url), "driving_license")
|
||||
# if image_croped is not None:
|
||||
# if result["data"] == []:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
# else:
|
||||
# result = {
|
||||
# "status": 200,
|
||||
# "content": result,
|
||||
# "message": "Success",
|
||||
# }
|
||||
# path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
||||
# cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_croped)
|
||||
# c_connector.process_driver_license_result((rq_id, result, path_image_croped))
|
||||
# return {"rq_id": rq_id}
|
||||
# elif image_croped is None:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {
|
||||
"status": 404,
|
||||
"content": {},
|
||||
}
|
||||
c_connector.process_driver_license_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_template_matching")
|
||||
def process_template_matching(rq_id, sub_id, folder_name, url, tmp_json, user_id):
|
||||
from TemplateMatching.src.ocr_master import Extractor
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
import urllib
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
extractor = Extractor()
|
||||
try:
|
||||
req = urllib.request.urlopen(url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
imgs = [img]
|
||||
image_aliged = extractor.image_alige(imgs, tmp_json)
|
||||
if image_aliged is None:
|
||||
result = {
|
||||
"status": 401,
|
||||
"content": "Image is not match with template",
|
||||
}
|
||||
c_connector.process_template_matching_result(
|
||||
(rq_id, result, None)
|
||||
)
|
||||
return {"rq_id": rq_id}
|
||||
else:
|
||||
output = extractor.extract_information(
|
||||
image_aliged, tmp_json
|
||||
)
|
||||
path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
||||
cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_aliged)
|
||||
if output == {}:
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_template_matching_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
else:
|
||||
result = {
|
||||
"document_type": "template_matching",
|
||||
"fields": []
|
||||
}
|
||||
print(output)
|
||||
for field in tmp_json["fields"]:
|
||||
print(field["label"])
|
||||
field_value = {
|
||||
"label": field["label"],
|
||||
"value": output[field["label"]],
|
||||
"box": [float(num) for num in field["box"]],
|
||||
"confidence": 0.98 #TODO confidence
|
||||
}
|
||||
result["fields"].append(field_value)
|
||||
|
||||
print(result)
|
||||
result = {"status": 200, "content": result}
|
||||
c_connector.process_template_matching_result(
|
||||
(rq_id, result, path_image_croped)
|
||||
)
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_template_matching_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
# @app.task(name="process_invoice")
|
||||
# def process_invoice(rq_id, url):
|
||||
# from celery_worker.client_connector import CeleryConnector
|
||||
# from Kie_Hoanglv.prediction import predict
|
||||
|
||||
# c_connector = CeleryConnector()
|
||||
# try:
|
||||
# print(url)
|
||||
# result = predict(str(url))
|
||||
# hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# hoadon = {"status": 404, "content": {}}
|
||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
@app.task(name="process_invoice")
|
||||
def process_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
from common.process_pdf import compile_output
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
c_connector.process_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_ocr_with_box")
|
||||
def process_ocr_with_box(rq_id, list_url):
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
from common.process_pdf import compile_output_ocr_base
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output_ocr_base(list_url)
|
||||
result = {"status": 200, "content": result, "message": "Success"}
|
||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
@ -1,8 +1,16 @@
|
||||
from celery_worker.worker_fi import app
|
||||
from celery_worker.client_connector_fi import CeleryConnector
|
||||
from common.process_pdf import compile_output_sbt
|
||||
from .task_warpper import VerboseTask
|
||||
import logging
|
||||
import logging.config
|
||||
from utils.logging.logging import LOGGER_CONFIG
|
||||
# Load the logging configuration
|
||||
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

@app.task(name="process_fi_invoice")
@app.task(base=VerboseTask,name="process_fi_invoice")
def process_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output_fi
@ -11,22 +19,22 @@ def process_invoice(rq_id, list_url):
try:
result = compile_output_fi(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
logger.info(hoadon)
c_connector.process_fi_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
logger.info(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_fi_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}


@app.task(name="process_sap_invoice")
@app.task(base=VerboseTask,name="process_sap_invoice")
def process_sap_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output

print(list_url)
logger.info(list_url)
c_connector = CeleryConnector()
try:
result = compile_output(list_url)
@ -34,12 +42,12 @@ def process_sap_invoice(rq_id, list_url):
c_connector.process_sap_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
logger.info(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_sap_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}

@app.task(name="process_manulife_invoice")
@app.task(base=VerboseTask,name="process_manulife_invoice")
def process_manulife_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output_manulife
@ -48,16 +56,16 @@ def process_manulife_invoice(rq_id, list_url):
try:
result = compile_output_manulife(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
logger.info(hoadon)
c_connector.process_manulife_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
logger.info(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_manulife_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}

@app.task(name="process_sbt_invoice")
@app.task(base=VerboseTask,name="process_sbt_invoice")
def process_sbt_invoice(rq_id, list_url, metadata):
# TODO: simply returning 200 and 404 doesn't make any sense
c_connector = CeleryConnector()
@ -65,12 +73,12 @@ def process_sbt_invoice(rq_id, list_url, metadata):
result = compile_output_sbt(list_url, metadata)
metadata['ai_inference_profile'] = result.pop("inference_profile")
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
logger.info(hoadon)
c_connector.process_sbt_invoice_result((rq_id, hoadon, metadata))
return {"rq_id": rq_id}
except Exception as e:
print(f"[ERROR]: Failed to extract invoice: {e}")
print(e)
logger.info(f"[ERROR]: Failed to extract invoice: {e}")
logger.info(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_sbt_invoice_result((rq_id, hoadon, metadata))
return {"rq_id": rq_id}
20
cope2n-ai-fi/celery_worker/task_warpper.py
Normal file
@ -0,0 +1,20 @@
from celery import Task
from celery.utils.log import get_task_logger
from utils.logging.local_storage import get_current_trace_id, set_current_trace_id
logger = get_task_logger(__name__)

class VerboseTask(Task):
abstract = True

def on_failure(self, exc, task_id, args, kwargs, einfo):
# Task failed. What do you want to do?
logger.error(f'FAILURE: Task: {self.name} - {task_id} | Task raised an exception: {exc}')

def on_success(self, retval, task_id, args, kwargs):
logger.info(f"SUCCESS: Task: {self.name} - {task_id} | retval: {retval} | args: {args} | kwargs: {kwargs}")

def before_start(self, task_id, args, kwargs):
trace_id = args[-1]
args.pop(-1)
set_current_trace_id(trace_id)
logger.info(f"BEFORE_START: Task: {self.name} - {task_id} | args: {args} | kwargs: {kwargs}")
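For context on how base=VerboseTask ties into the tasks above: the producer appends the current trace id as the last positional argument when it sends a task, and before_start pops that argument back off and stores it in thread-local storage before the task body runs, so log records emitted inside the task carry the same trace id. A minimal producer-side sketch, assuming the helpers above are importable; the module layout, broker URL and sample arguments are illustrative, not part of this commit:

    # Hypothetical producer sketch; sample values are assumptions.
    from celery import Celery
    from utils.logging.local_storage import get_current_trace_id

    app = Celery("postman", broker="amqp://test:test@rabbitmq:5672")

    args = ("rq_123", ["https://example.com/invoice_page_1.jpg"])
    args += (get_current_trace_id(),)  # trace id rides along as the last positional argument
    app.send_task("process_fi_invoice", args)  # VerboseTask.before_start pops it and calls set_current_trace_id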
@ -1,41 +0,0 @@
from celery import Celery
from kombu import Queue, Exchange
import environ
env = environ.Env(
DEBUG=(bool, False)
)

app: Celery = Celery(
"postman",
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
# backend="rpc://",
include=[
"celery_worker.mock_process_tasks",
],
broker_transport_options={'confirm_publish': False},
)
task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False
app.conf.update(
{
"result_expires": 3600,
"task_queues": [
Queue("id_card"),
Queue("driver_license"),
Queue("invoice"),
Queue("ocr_with_box"),
Queue("template_matching"),
],
"task_routes": {
"process_id": {"queue": "id_card"},
"process_driver_license": {"queue": "driver_license"},
"process_invoice": {"queue": "invoice"},
"process_ocr_with_box": {"queue": "ocr_with_box"},
"process_template_matching": {"queue": "template_matching"},
},
}
)

if __name__ == "__main__":
argv = ["celery_worker.worker", "--loglevel=INFO", "--pool=solo"] # Window opts
app.worker_main(argv)
@ -1,6 +1,7 @@
from celery import Celery
from kombu import Queue, Exchange
import environ

env = environ.Env(
DEBUG=(bool, False)
)
@ -13,6 +14,7 @@ app: Celery = Celery(
],
broker_transport_options={'confirm_publish': False},
)

task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False
app.conf.update(
@ -98,4 +98,3 @@ if __name__ == "__main__":
image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
print('[INFO] Done')
@ -7,6 +7,13 @@ sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ??????
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


class KVUPredictor:
@ -18,9 +25,9 @@ class KVUPredictor:
self.dummy_idx = dummy_idx
self.mode = mode

print('[INFO] Loading Key-Value Understanding model ...')
logger.info('[INFO] Loading Key-Value Understanding model ...')
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
print("[INFO] Loaded model")
logger.info("[INFO] Loaded model")

if mode == 3:
self.max_window_count = cfg.train.max_window_count
@ -39,7 +46,7 @@ class KVUPredictor:
cfg.stage = self.mode
backbone_type = cfg.model.backbone

print('[INFO] Checkpoint:', ckpt_path)
logger.info('[INFO] Checkpoint:', ckpt_path)
net = get_model(cfg)
load_model_weight(net, ckpt_path)
net.to('cuda')
@ -6,7 +6,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from utils.ema_callbacks import EMA

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

def _update_config(cfg):
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
@ -14,7 +20,7 @@ def _update_config(cfg):

# set per-gpu batch size
num_devices = torch.cuda.device_count()
print('No. devices:', num_devices)
logger.info('No. devices:', num_devices)
for mode in ["train", "val"]:
new_batch_size = cfg[mode].batch_size // num_devices
cfg[mode].batch_size = new_batch_size
@ -89,15 +95,15 @@ def create_exp_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Experiment dir : {}'.format(save_dir))
logger.info("DIR already existed.")
logger.info('Experiment dir : {}'.format(save_dir))

def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
logger.info("DIR already existed.")
logger.info('Save dir : {}'.format(save_dir))

def load_checkpoint(ckpt_path, model, key_include):
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
@ -109,7 +115,7 @@ def load_checkpoint(ckpt_path, model, key_include):
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
del state_dict[key]
model.load_state_dict(state_dict, strict=True)
print(f"Load checkpoint at {ckpt_path}")
logger.info(f"Load checkpoint at {ckpt_path}")
return model

def load_model_weight(net, pretrained_model_file):
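A side note on the print-to-logger swaps above: the standard logging module interpolates lazily with %-style placeholders, so a call such as logger.info('No. devices:', num_devices) does not concatenate its arguments the way print does; the extra argument is only consumed if the message contains a placeholder. A minimal sketch of the two idiomatic forms (illustrative only, not part of this commit):

    import logging

    logger = logging.getLogger(__name__)
    num_devices = 2  # assumed sample value

    logger.info("No. devices: %s", num_devices)   # lazy %-style interpolation
    logger.info(f"No. devices: {num_devices}")    # eager f-string formatting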
@ -6,14 +6,21 @@ import sys
# from src.ocr import OcrEngine
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ??????
import serve_model
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


# def load_ocr_engine() -> OcrEngine:
def load_ocr_engine() -> OcrEngine:
print("[INFO] Loading engine...")
logger.info("[INFO] Loading engine...")
# engine = OcrEngine()
engine = serve_model.engine
print("[INFO] Engine loaded")
logger.info("[INFO] Engine loaded")
return engine

def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
@ -10,25 +10,31 @@ from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
from utils.kvu_dictionary import vat_dictionary, ap_dictionary

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
logger.info("DIR already existed.")
logger.info('Save dir : {}'.format(save_dir))

def pdf2image(pdf_dir, save_dir):
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
print('No. pdf files:', len(pdf_files))
logger.info('No. pdf files:', len(pdf_files))

for file in tqdm(pdf_files):
pages = convert_from_path(file, 500)
for i, page in enumerate(pages):
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
print('Done!!!')
logger.info('Done!!!')

def xyxy2xywh(bbox):
return [
@ -246,7 +252,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
try:
outputs.append([word_groups[wg_from], word_groups[wg_to]])
except Exception as e:
print('Not valid pair:', wg_from, wg_to)
logger.info('Not valid pair:', wg_from, wg_to)
return outputs


@ -264,7 +270,7 @@ def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['othe
triplet_pairs = []
single_pairs = []
table = []
# print('key2values_relations', key2values_relations)
# logger.info('key2values_relations', key2values_relations)
for key_group_id, list_value_group_ids in key2values_relations.items():
if len(list_value_group_ids) == 0: continue
elif len(list_value_group_ids) == 1:
@ -355,7 +361,7 @@ def get_vat_information(outputs):
for pair in outputs['single']:
for raw_key_name, value in pair.items():
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
@ -369,7 +375,7 @@ def get_vat_information(outputs):
for key, value_list in triplet.items():
if len(value_list) == 1:
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
@ -381,7 +387,7 @@ def get_vat_information(outputs):

for pair in value_list:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
@ -394,7 +400,7 @@ def get_vat_information(outputs):
for table_row in outputs['table']:
for pair in table_row:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
@ -461,7 +467,7 @@ def get_ap_table_information(outputs):
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
if header_name in list(item.keys()):
item[header_name].append({
'content': cell['text'],
@ -515,7 +521,7 @@ def get_ap_information(outputs):
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

if key_name in list(single_pairs):
single_pairs[key_name].append({
@ -5,6 +5,13 @@ import copy
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
nltk.download('words')
words = set(nltk.corpus.words.words())
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
@ -31,7 +38,7 @@ def remove_punctuation(text):

def remove_accents(input_str, s0, s1):
s = ''
# print input_str.encode('utf-8')
# logger.info input_str.encode('utf-8')
for c in input_str:
if c in s1:
s += s0[s1.index(c)]
@ -159,7 +166,7 @@ def post_process_for_item(item: dict) -> dict:
elif mis_key[0] == check_keys[2]:
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
except Exception as e:
print("Cannot post process this item with error:", e)
logger.info("Cannot post process this item with error:", e)
return item
@ -1,5 +1,12 @@
import xml.etree.ElementTree as ET
from datetime import datetime
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)
ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#")


@ -124,7 +131,7 @@ def replace_xml_values(xml_str, replacement_dict):
formatted_date = date_obj.strftime("%Y-%m-%d")
nlap_element.text = formatted_date
except ValueError:
print(f"Invalid date format for {key}: {value}")
logger.info(f"Invalid date format for {key}: {value}")
nlap_element.text = value
else:
element = root.find(f".//{key}")
@ -133,7 +140,7 @@ def replace_xml_values(xml_str, replacement_dict):
ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#")
return ET.tostring(root, encoding="unicode")
except ET.ParseError as e:
print(f"Error parsing XML: {e}")
logger.info(f"Error parsing XML: {e}")
return None
@ -5,6 +5,13 @@ det_ckpt = "yolox-s-general-text-pretrain-20221226"
cls_ckpt = "satrn-lite-general-pretrain-20230106"

engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt)
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


def ocr_predict(image):
@ -22,7 +29,7 @@ def ocr_predict(image):
list_lines, _ = words_to_lines(lWords)
return list_lines
except AssertionError as e:
print(e)
logger.info(e)
list_lines = []
return list_lines
@ -3,7 +3,13 @@ from datetime import datetime
from sklearn.metrics import classification_report
from common.utils.utils import read_json
from underthesea import word_tokenize

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

class DatetimeCorrector:
@staticmethod
@ -92,8 +98,6 @@ class DatetimeCorrector:
for k, d in data.items():
if k in lexcludes:
continue
if k == "inv_SDV_215":
print("debugging")
pred = DatetimeCorrector.correct(d["pred"])
label = DatetimeCorrector.correct(d["label"])
ddata[k] = {}
@ -103,11 +107,8 @@ class DatetimeCorrector:
ddata[k]["Post-processed"] = pred
y_pred.append(pred == label)
y_true.append(1)
if k == "invoice_1219_000":
print("\n", k, '-' * 50)
print(pred, "------", d["pred"])
print(label, "------", d["label"])
print(classification_report(y_true, y_pred))

logger.info(classification_report(y_true, y_pred))
import pandas as pd
df = pd.DataFrame.from_dict(ddata, orient="index")
df.to_csv(f"result/datetime_post_processed_{type_column}.csv")
@ -11,6 +11,13 @@ from common.utils_kvu.split_docs import split_docs, merge_sbt_output
# from api.Kie_Invoice_AP.prediction_fi import predict_fi
# from api.manulife.predict_manulife import predict as predict_manulife
from api.sdsap_sbt.prediction_sbt import predict as predict_sbt
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/'

@ -188,11 +195,11 @@ def compile_output_manulife(list_url):
outputs = []
for page in list_url:
output_model = predict_manulife(page['page_number'], page['file_url']) # gotta be predict_manulife(), for the time being, this function is not avaible, we just leave a dummy function here instead
print("output_model", output_model)
logger.info("output_model", output_model)
outputs.append(output_model)
print("outputs", outputs)
logger.info("outputs", outputs)
documents = split_docs(outputs)
print("documents", documents)
logger.info("documents", documents)
results = {
"total_pages": len(list_url),
"ocr_num_pages": len(list_url),
@ -1,5 +1,12 @@
import cv2
import numpy as np
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

# tuplify
def tup(point):
@ -85,7 +92,7 @@ while not finished:
finished = True

# check progress
print("Len Boxes: " + str(len(boxes)))
logger.info("Len Boxes: " + str(len(boxes)))

# draw boxes # comment this section out to run faster
copy = np.copy(orig)
@ -4,6 +4,13 @@ from sdsvtr import StandaloneSATRNRunner
from sdsvtd import StandaloneYOLOXRunner
import urllib
import cv2
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


class YoloX:
@ -50,8 +57,8 @@ class OcrEngineForYoloX_Invoice:
lbboxes.append(bbox_)
lcropped_img.append(crop_img)
except AssertionError as e:
print(e)
print(f"[ERROR]: Skipping invalid bbox in image")
logger.info(e)
logger.info(f"[ERROR]: Skipping invalid bbox in image")
lwords, _ = self.cls.inference(lcropped_img)
return lbboxes, lwords

@ -72,6 +79,6 @@ class OcrEngineForYoloX_ID_Driving:
lbboxes.append(bbox_)
lcropped_img.append(crop_img)
except AssertionError:
print(f"[ERROR]: Skipping invalid bbox image in ")
logger.info(f"[ERROR]: Skipping invalid bbox image in ")
lwords, _ = self.cls.inference(lcropped_img)
return lbboxes, lwords
@ -5,7 +5,13 @@ from xml.dom.expatbuilder import parseString
from lxml.etree import Element, tostring, SubElement
import tqdm
from common.utils.global_variables import *

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
"""_summary_
@ -75,7 +81,7 @@ def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
node_ymax = SubElement(node_bndbox, "ymax")
node_ymax.text = bottom

xml = tostring(node_root, pretty_print=True)
xml = tostring(node_root, pretty_logger.info=True)
dom = parseString(xml)
with open(xml_pth, "w+", encoding="utf-8") as f:
dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8")
@ -105,7 +111,7 @@ def check_iou(box1: Box, box2: Box, threshold=0.9):
ymax_intersect * ymin_intersect
)
union = area1 + area2 - area_intersect
print(union)
logger.info(union)
iou = area_intersect / area1
if iou > threshold:
return True
@ -1,5 +1,12 @@
from builtins import dict
from common.utils.global_variables import *
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

MIN_IOU_HEIGHT = 0.7
MIN_WIDTH_LINE_RATIO = 0.05
@ -62,7 +69,7 @@ class Word_group:
if word.text != "✪":
for w in self.list_words:
if word.word_id == w.word_id:
print("Word id collision")
logger.info("Word id collision")
return False
word.word_group_id = self.word_group_id #
word.line_id = self.line_id
@ -120,7 +127,7 @@ class Line:
if word_group.list_words is not None:
for wg in self.list_word_groups:
if word_group.word_group_id == wg.word_group_id:
print("Word_group id collision")
logger.info("Word_group id collision")
return False

self.list_word_groups.append(word_group)
@ -204,7 +211,7 @@ class Paragraph:
if line.list_word_groups is not None:
for l in self.list_lines:
if line.line_id == l.line_id:
print("Line id collision")
logger.info("Line id collision")
return False
for i in range(len(line.list_word_groups)):
line.list_word_groups[
@ -288,7 +295,7 @@ def prepare_line(words):
new_line.merge_word(word)
lines.append(new_line)

# print(len(lines))
# logger.info(len(lines))
# sort line from top to bottom according top coordinate
lines.sort(key=lambda x: x.boundingbox[1])
return lines
@ -381,7 +388,7 @@ def words_to_lines(words, check_special_lines=True): # words is list of Word in
# sort word by top
words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0]))
number_of_word = len(words)
# print(number_of_word)
# logger.info(number_of_word)
# sort list words to list lines, which have not contained word_group yet
lines = prepare_line(words)

@ -402,7 +409,7 @@ def near(word_group1: Word_group, word_group2: Word_group):
if overlap > 0:
return True
if abs(overlap / min_height) < 1.5:
print("near enough", abs(overlap / min_height), overlap, min_height)
logger.info("near enough", abs(overlap / min_height), overlap, min_height)
return True
return False
@ -102,8 +102,6 @@ def merge_sbt_output(loutputs):
})
return output

print("concat outputs: \n", loutputs)

merged_output = []
combined_output = {"retailername": None,
"sold_to_party": None,
0
cope2n-ai-fi/utils/__init__.py
Normal file
0
cope2n-ai-fi/utils/logging/__init___.py
Normal file
15
cope2n-ai-fi/utils/logging/local_storage.py
Normal file
@ -0,0 +1,15 @@
from threading import local

_thread_locals = local()

def get_current_request():
return getattr(_thread_locals, 'request', None)

def set_current_request(request):
_thread_locals.request = request

def set_current_trace_id(trace_id):
_thread_locals.trace_id = trace_id

def get_current_trace_id():
return getattr(_thread_locals, 'trace_id', None)
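local_storage.py keeps the per-request state in a threading.local(), so each worker thread reads back only the trace id it stored itself. A small sketch of that isolation, assuming the module above is importable (the sample trace ids are illustrative):

    # Hypothetical sketch; the two trace ids are assumed sample values.
    import threading
    from utils.logging.local_storage import set_current_trace_id, get_current_trace_id

    def handle(trace_id):
        set_current_trace_id(trace_id)
        # Only this thread sees the value it just stored.
        assert get_current_trace_id() == trace_id

    t1 = threading.Thread(target=handle, args=("trace-aaa",))
    t2 = threading.Thread(target=handle, args=("trace-bbb",))
    t1.start(); t2.start(); t1.join(); t2.join()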
61
cope2n-ai-fi/utils/logging/logging.py
Normal file
@ -0,0 +1,61 @@
import os
import logging
import logging.config
from .local_storage import set_current_trace_id, get_current_trace_id

class TraceIDLogFilter(logging.Filter):
def filter(self, record):
trace_id = get_current_trace_id()
record.trace_id = trace_id
return True

LOG_ROOT = os.getenv("LOG_ROOT", "/home/tuanlv/workspace/02-KVU/sdsvkvu/logs")

LOGGER_CONFIG = {
"version": 1,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(trace_id)s - %(message)s"
}
},
"filters": {
"trace_id": {
"()": TraceIDLogFilter
}
},
"handlers": {
'console': {
'class': 'logging.StreamHandler',
'formatter': 'default',
'filters': ['trace_id'],
},
"file_handler": {
"class": "logging.handlers.TimedRotatingFileHandler",
"filename": f"{LOG_ROOT}/sbt_idp_AI.log",
"level": "DEBUG",
"formatter": "default",
"filters": ["trace_id"],
"when": "midnight",
"interval": 1,
'backupCount': 10,
}
},
"loggers": {
"sdsvkvu": {
"level": "DEBUG",
"handlers": ["console", "file_handler"],
},
'': {
'handlers': ['console', 'file_handler'],
'level': 'INFO',
},
'django': {
'handlers': ['console', 'file_handler'],
'level': 'INFO',
},
'celery': {
'handlers': ['console', 'file_handler'],
'level': 'DEBUG',
},
}
}
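logging.py is what the repeated import blocks throughout this commit consume: a module loads LOGGER_CONFIG once with dictConfig and then asks for a named logger, and the trace_id filter guarantees the %(trace_id)s field in the default formatter is always populated (it is None outside a request or task). A minimal usage sketch, assuming LOG_ROOT points at a writable directory:

    # Hypothetical usage sketch; the log message is an assumed sample.
    import logging
    import logging.config
    from utils.logging.logging import LOGGER_CONFIG

    logging.config.dictConfig(LOGGER_CONFIG)   # wires the console and rotating file handlers
    logger = logging.getLogger(__name__)
    logger.info("engine ready")                # record is formatted as "<time> - <name> - INFO - <trace_id> - engine ready"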
@ -282,6 +282,7 @@ LOGGING = {
'console': {
'class': 'logging.StreamHandler',
'formatter': 'verbose',
'filters': ['trace_id'],
},
'file': {
"class": 'logging.handlers.TimedRotatingFileHandler',
@ -290,6 +291,7 @@ LOGGING = {
"interval": 1,
'backupCount': 10,
'formatter': 'verbose',
'filters': ['trace_id'],
},
},
'loggers': {
@ -2,7 +2,7 @@ from celery import Celery

from fwd import settings
from fwd_api.exception.exceptions import GeneralException
from fwd_api.middleware.local_storage import get_current_request
from fwd_api.middleware.local_storage import get_current_trace_id
from kombu.utils.uuid import uuid
from celery.utils.log import get_task_logger
logger = get_task_logger(__name__)
@ -128,9 +128,9 @@ class CeleryConnector:
def send_task(self, name=None, args=None, countdown=None):
if name not in self.task_routes or 'queue' not in self.task_routes[name]:
raise GeneralException("System")
# task_id = args[0] + "_" + uuid()[:4] if isinstance(args, tuple) and is_it_an_index(args[0]) else uuid()
request = get_current_request()
task_id = request.META.get('X-Trace-ID', uuid()) + "_" + uuid()[:4] if request else uuid()
task_id = args[0] + "_" + uuid()[:4] if isinstance(args, tuple) and is_it_an_index(args[0]) else uuid()
trace_id = get_current_trace_id()
args += (trace_id,) # add trace_id to args then remove before start
logger.info(f"SEND task name: {name} - {task_id} | args: {args} | countdown: {countdown}")
return self.app.send_task(name, args, queue=self.task_routes[name]['queue'], expires=300, countdown=countdown, task_id=task_id)
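The backend uses the same handshake as the AI worker: send_task appends the trace id from thread-local storage as the last positional argument, and the fwd_api VerboseTask shown further down pops exactly one trailing element in before_start before the task body runs. Illustratively (sample values are assumptions):

    # Shape of the argument tuple around send_task; values are illustrative only.
    args = ("rq_001", ["https://example.com/doc.pdf"])
    args += (get_current_trace_id(),)   # -> ("rq_001", [...], trace_id)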
@ -16,6 +16,7 @@ from ..utils import process as ProcessUtil
from ..utils import s3 as S3Util
from ..utils.accuracy import validate_feedback_file
from fwd_api.constant.common import FileCategory
from fwd_api.middleware.local_storage import get_current_trace_id
import csv
import json
import copy
@ -222,6 +223,8 @@ def process_pdf(rq_id, sub_id, p_type, user_id, files):
file_meta["preprocessing_time"] = preprocessing_time
file_meta["index_to_image_type"] = b_url["index_to_image_type"]
file_meta["subsidiary"] = new_request.subsidiary
file_meta["request_id"] = rq_id
file_meta["trace_id"] = get_current_trace_id()
to_queue.append((fractorized_request_id, sub_id, [b_url], user_id, p_type, file_meta))

# Send to next queue
@ -1,5 +1,6 @@
from celery import Task
from celery.utils.log import get_task_logger
from fwd_api.middleware.local_storage import get_current_trace_id, set_current_trace_id
logger = get_task_logger(__name__)

class VerboseTask(Task):
@ -13,4 +14,7 @@ class VerboseTask(Task):
logger.info(f"SUCCESS: Task: {self.name} - {task_id} | retval: {retval} | args: {args} | kwargs: {kwargs}")

def before_start(self, task_id, args, kwargs):
trace_id = args[-1]
args.pop(-1)
set_current_trace_id(trace_id)
logger.info(f"BEFORE_START: Task: {self.name} - {task_id} | args: {args} | kwargs: {kwargs}")
@ -7,3 +7,9 @@ def get_current_request():

def set_current_request(request):
_thread_locals.request = request

def set_current_trace_id(trace_id):
_thread_locals.trace_id = trace_id

def get_current_trace_id():
return getattr(_thread_locals, 'trace_id', None)
@ -2,7 +2,7 @@ import logging
import uuid

from django.utils.deprecation import MiddlewareMixin
from .local_storage import set_current_request, get_current_request
from .local_storage import set_current_trace_id, get_current_trace_id

logger = logging.getLogger(__name__)

@ -10,7 +10,7 @@ class LoggingMiddleware(MiddlewareMixin):
def process_request(self, request):
trace_id = request.headers.get('X-Trace-ID', str(uuid.uuid4()))
request.META['X-Trace-ID'] = trace_id
set_current_request(request)
set_current_trace_id(trace_id)

request_body = ""
content_type = request.headers.get("Content-Type", "")
@ -41,7 +41,6 @@ class LoggingMiddleware(MiddlewareMixin):

class TraceIDLogFilter(logging.Filter):
def filter(self, record):
request = get_current_request()
trace_id = request.META.get('X-Trace-ID', 'unknown') if request else 'unknown'
trace_id = get_current_trace_id()
record.trace_id = trace_id
return True
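Because the trace id now originates from the X-Trace-ID request header (falling back to a fresh uuid4), an upstream caller can correlate its own logs with the backend's by sending that header explicitly. A hedged client-side sketch; the endpoint URL and payload are assumptions, not part of this commit:

    # Hypothetical client sketch; endpoint and file are assumed.
    import uuid
    import requests

    trace_id = str(uuid.uuid4())
    resp = requests.post(
        "https://backend.example.com/api/endpoint",
        headers={"X-Trace-ID": trace_id},
        files={"file": open("invoice.pdf", "rb")},
    )
    # The same trace_id then appears in the LoggingMiddleware / TraceIDLogFilter output for this request.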
@ -12,16 +12,16 @@ services:
shm_size: 10gb
dockerfile: Dockerfile
shm_size: 10gb
restart: always
networks:
- ctel-sbt
privileged: true
image: sidp/cope2n-ai-fi-sbt:latest
# runtime: nvidia
environment:
- LOG_ROOT=${AI_LOG_ROOT}
- PYTHONPATH=${PYTHONPATH}:/workspace/cope2n-ai-fi # For import module
- CELERY_BROKER=amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq-sbt:5672
# - CUDA_VISIBLE_DEVICES=0
- CUDA_VISIBLE_DEVICES=1
volumes:
- ./cope2n-ai-fi:/workspace/cope2n-ai-fi # for dev container only
working_dir: /workspace/cope2n-ai-fi

@ -15,6 +15,7 @@ services:
- ctel-sbt
privileged: true
environment:
- LOG_ROOT=${AI_LOG_ROOT}
- CELERY_BROKER=amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq-sbt:5672
working_dir: /workspace/cope2n-ai-fi
command: bash run.sh