commit
88899ad394
@@ -81,7 +81,6 @@ def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, process
         # vat_outputs_invoice = export_kvu_for_all(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)
         vat_outputs_invoice = export_kvu_for_manulife(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)

-        print(vat_outputs_invoice)
     return vat_outputs_invoice

@@ -105,7 +104,6 @@ def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVU
     list_images = []
     for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png']:
         list_images += glob.glob(os.path.join(dir_path, f'*.{ext}'))
-    print('No. images:', len(list_images))
     for img_path in tqdm(list_images):
         load_groundtruth(img_path, json_dir, save_dir, predictor, processor, export_img)

@@ -133,5 +131,4 @@ if __name__ == "__main__":
     create_dir(args.save_dir)
     image_path = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
     save_dir = "/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test"
     predict_image(image_path, save_dir, predictor, processor)
-    print('[INFO] Done')
@@ -7,6 +7,13 @@ from torch.utils.data.dataloader import DataLoader

 from lightning_modules.data_modules.kvu_dataset import KVUDataset, KVUEmbeddingDataset
 from lightning_modules.utils import _get_number_samples
+import logging
+import logging.config
+from utils.logging.logging import LOGGER_CONFIG
+# Load the logging configuration
+logging.config.dictConfig(LOGGER_CONFIG)
+# Get the logger
+logger = logging.getLogger(__name__)

 class KVUDataModule(pl.LightningDataModule):
     def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor):
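The block above imports LOGGER_CONFIG from utils/logging/logging.py, which is not included in this diff. For context only, a minimal dictConfig-compatible configuration could look like the sketch below; every key and value in it is an assumption, not the project's actual config.

# Hypothetical sketch of what utils/logging/logging.py might provide (assumed, not from this diff)
LOGGER_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        # one console formatter with timestamp, level and logger name
        "default": {"format": "%(asctime)s %(levelname)s %(name)s: %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}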
@@ -61,7 +68,7 @@ class KVUDataModule(pl.LightningDataModule):
             f"Not supported stage: {self.cfg.stage}"
         )

-        print('No. training samples:', len(dataset))
+        logger.info('No. training samples:', len(dataset))

         data_loader = DataLoader(
             dataset,
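One behavioral note on the replacement above: logger.info('No. training samples:', len(dataset)) keeps print-style argument passing, but the logging module treats extra positional arguments as %-format parameters for the message string. Since the message has no placeholder, the call triggers a "--- Logging error ---" at emit time instead of printing the count. A minimal sketch of the same messages in standard lazy-formatting style (suggested form, not part of this commit):

# Suggested form; logging interpolates the arguments into %-style placeholders lazily.
logger.info("No. training samples: %d", len(dataset))
logger.info("No. validation samples: %d", len(dataset))
logger.info("Elapsed time for loading training data: %s", elapsed_time)

The same applies to the other multi-argument calls introduced by this commit, e.g. logger.info('Checkpoint:', ckpt_path) and logger.error(e, ' at ', simg_path).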
@@ -72,7 +79,7 @@ class KVUDataModule(pl.LightningDataModule):
         )

         elapsed_time = time.time() - start_time
-        print(f"Elapsed time for loading training data: {elapsed_time}")
+        logger.info(f"Elapsed time for loading training data: {elapsed_time}")

         return data_loader

@@ -101,7 +108,7 @@ class KVUDataModule(pl.LightningDataModule):
             f"Not supported stage: {self.cfg.stage}"
         )

-        print('No. validation samples:', len(dataset))
+        logger.info('No. validation samples:', len(dataset))

         data_loader = DataLoader(
             dataset,
@ -18,6 +18,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union, Tuple, List
|
from typing import Union, Tuple, List
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
current_dir = os.getcwd()
|
current_dir = os.getcwd()
|
||||||
|
|
||||||
|
|
||||||
@ -42,10 +49,10 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def load_engine(opt) -> OcrEngine:
|
def load_engine(opt) -> OcrEngine:
|
||||||
print("[INFO] Loading engine...")
|
logger.info("Loading engine...")
|
||||||
kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {}
|
kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {}
|
||||||
engine = OcrEngine(**kw)
|
engine = OcrEngine(**kw)
|
||||||
print("[INFO] Engine loaded")
|
logger.info("[INFO] Engine loaded")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@ -64,7 +71,7 @@ def get_paths_from_opt(opt) -> Tuple[Path, Path]:
|
|||||||
Path(save_dir), Path(base_dir))
|
Path(save_dir), Path(base_dir))
|
||||||
if not save_dir.exists():
|
if not save_dir.exists():
|
||||||
save_dir.mkdir()
|
save_dir.mkdir()
|
||||||
print("[INFO]: Creating folder ", save_dir)
|
logger.info("Creating folder ", save_dir)
|
||||||
return input_image, save_dir
|
return input_image, save_dir
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +112,7 @@ def process_dir(
|
|||||||
img_path.stem + ".txt"))
|
img_path.stem + ".txt"))
|
||||||
process_img(img, save_path, engine, export_img)
|
process_img(img, save_path, engine, export_img)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('[ERROR]: ', e, ' at ', simg_path)
|
logger.error(e, ' at ', simg_path)
|
||||||
continue
|
continue
|
||||||
ddata["img_path"].append(simg_path)
|
ddata["img_path"].append(simg_path)
|
||||||
ddata["ocr_path"].append(save_path)
|
ddata["ocr_path"].append(save_path)
|
||||||
@ -125,7 +132,6 @@ def process_csv(csv_path: str, engine: OcrEngine) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
opt = get_args()
|
opt = get_args()
|
||||||
engine = load_engine(opt)
|
engine = load_engine(opt)
|
||||||
print("[INFO]: OCR engine settings:", engine.settings)
|
|
||||||
img, save_dir = get_paths_from_opt(opt)
|
img, save_dir = get_paths_from_opt(opt)
|
||||||
|
|
||||||
lskip_dir = []
|
lskip_dir = []
|
||||||
@ -137,7 +143,6 @@ if __name__ == "__main__":
|
|||||||
elif img.suffix in ImageReader.supported_ext:
|
elif img.suffix in ImageReader.supported_ext:
|
||||||
process_img(str(img), save_dir, engine, opt.export_img)
|
process_img(str(img), save_dir, engine, opt.export_img)
|
||||||
elif img.suffix == '.csv':
|
elif img.suffix == '.csv':
|
||||||
print("[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used")
|
|
||||||
process_csv(img, engine)
|
process_csv(img, engine)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img))
|
raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img))
|
||||||
|
@ -3,7 +3,13 @@ from typing import Optional, List
|
|||||||
import cv2
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .utils import visualize_bbox_and_label
|
from .utils import visualize_bbox_and_label
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Box:
|
class Box:
|
||||||
def __init__(self, x1, y1, x2, y2, conf=-1., label=""):
|
def __init__(self, x1, y1, x2, y2, conf=-1., label=""):
|
||||||
@ -189,7 +195,7 @@ class Word_group:
|
|||||||
if word.text != "✪":
|
if word.text != "✪":
|
||||||
for w in self.list_words:
|
for w in self.list_words:
|
||||||
if word.word_id == w.word_id:
|
if word.word_id == w.word_id:
|
||||||
print("Word id collision")
|
logger.info("Word id collision")
|
||||||
return False
|
return False
|
||||||
word.word_group_id = self.word_group_id #
|
word.word_group_id = self.word_group_id #
|
||||||
word.line_id = self.line_id
|
word.line_id = self.line_id
|
||||||
@ -260,7 +266,7 @@ class Line:
|
|||||||
if word_group.list_words is not None:
|
if word_group.list_words is not None:
|
||||||
for wg in self.list_word_groups:
|
for wg in self.list_word_groups:
|
||||||
if word_group.word_group_id == wg.word_group_id:
|
if word_group.word_group_id == wg.word_group_id:
|
||||||
print("Word_group id collision")
|
logger.info("Word_group id collision")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.list_word_groups.append(word_group)
|
self.list_word_groups.append(word_group)
|
||||||
@ -352,7 +358,7 @@ class Paragraph:
|
|||||||
if line.list_word_groups is not None:
|
if line.list_word_groups is not None:
|
||||||
for l in self.list_lines:
|
for l in self.list_lines:
|
||||||
if line.line_id == l.line_id:
|
if line.line_id == l.line_id:
|
||||||
print("Line id collision")
|
logger.info("Line id collision")
|
||||||
return False
|
return False
|
||||||
for i in range(len(line.list_word_groups)):
|
for i in range(len(line.list_word_groups)):
|
||||||
line.list_word_groups[
|
line.list_word_groups[
|
||||||
|
@ -16,7 +16,13 @@ from .dto import Word, Line, Page, Document, Box
|
|||||||
# from .word_formation import wo rds_to_lines_mmocr as words_to_lines
|
# from .word_formation import wo rds_to_lines_mmocr as words_to_lines
|
||||||
from .word_formation import words_to_lines_tesseract as words_to_lines
|
from .word_formation import words_to_lines_tesseract as words_to_lines
|
||||||
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"
|
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class OcrEngine:
|
class OcrEngine:
|
||||||
def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict):
|
def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict):
|
||||||
@ -35,7 +41,7 @@ class OcrEngine:
|
|||||||
|
|
||||||
if "cuda" in self.__settings["device"]:
|
if "cuda" in self.__settings["device"]:
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("[WARNING]: CUDA is not available, running with cpu instead")
|
logger.warning("CUDA is not available, running with cpu instead")
|
||||||
self.__settings["device"] = "cpu"
|
self.__settings["device"] = "cpu"
|
||||||
self._detector = StandaloneYOLOXRunner(
|
self._detector = StandaloneYOLOXRunner(
|
||||||
version=self.__settings["detector"],
|
version=self.__settings["detector"],
|
||||||
|
@ -12,7 +12,13 @@ from pdf2image import convert_from_path
|
|||||||
from deskew import determine_skew
|
from deskew import determine_skew
|
||||||
from jdeskew.estimator import get_angle
|
from jdeskew.estimator import get_angle
|
||||||
from jdeskew.utility import rotate as jrotate
|
from jdeskew.utility import rotate as jrotate
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def post_process_recog(text: str) -> str:
|
def post_process_recog(text: str) -> str:
|
||||||
text = text.replace("✪", " ")
|
text = text.replace("✪", " ")
|
||||||
@ -30,7 +36,7 @@ class Timer:
|
|||||||
def __exit__(self, func: Callable, *args):
|
def __exit__(self, func: Callable, *args):
|
||||||
self.end_time = time.perf_counter()
|
self.end_time = time.perf_counter()
|
||||||
self.elapsed_time = self.end_time - self.start_time
|
self.elapsed_time = self.end_time - self.start_time
|
||||||
print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds")
|
logger.info(f"{self.name} took : {self.elapsed_time:.6f} seconds")
|
||||||
|
|
||||||
|
|
||||||
def rotate(
|
def rotate(
|
||||||
@@ -201,8 +207,8 @@ class ImageReader:
                 ImageReader.validate_img_path(img_path)
                 limgs.append(ImageReader._read(img_path))
             except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e:
-                print("[ERROR]: ", e)
-                print("[INFO]: Skipping image {}".format(img_path))
+                logger.error(e)
+                logger.error("Skipping image {}".format(img_path))
         return limgs

     @staticmethod
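A side note on the error path above (a suggestion, not a change made by this commit): logger.error(e) records only the exception message and drops the traceback, while Logger.exception logs at ERROR level and appends the active traceback automatically. A minimal, self-contained sketch; the reader argument and helper name are hypothetical stand-ins for ImageReader:

import logging

logger = logging.getLogger(__name__)

def read_images_safely(reader, img_paths):
    # Hypothetical helper, only to illustrate traceback-preserving logging.
    images = []
    for img_path in img_paths:
        try:
            images.append(reader._read(img_path))
        except (FileNotFoundError, NotImplementedError, IsADirectoryError):
            # Unlike logger.error(e), logger.exception keeps the traceback.
            logger.exception("Skipping image %s", img_path)
    return images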
@ -2,6 +2,14 @@ from builtins import dict
|
|||||||
from .dto import Word, Line, Word_group, Box
|
from .dto import Word, Line, Word_group, Box
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional, List, Tuple, Union
|
from typing import Optional, List, Tuple, Union
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MIN_IOU_HEIGHT = 0.7
|
MIN_IOU_HEIGHT = 0.7
|
||||||
MIN_WIDTH_LINE_RATIO = 0.05
|
MIN_WIDTH_LINE_RATIO = 0.05
|
||||||
|
|
||||||
@ -485,7 +493,7 @@ def near(word_group1: Word_group, word_group2: Word_group):
|
|||||||
if overlap > 0:
|
if overlap > 0:
|
||||||
return True
|
return True
|
||||||
if abs(overlap / min_height) < 1.5:
|
if abs(overlap / min_height) < 1.5:
|
||||||
print("near enough", abs(overlap / min_height), overlap, min_height)
|
logger.info("near enough", abs(overlap / min_height), overlap, min_height)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -9,7 +9,14 @@ sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: ???????
|
|||||||
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
||||||
from model import get_model
|
from model import get_model
|
||||||
from utils import load_model_weight
|
from utils import load_model_weight
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class KVUPredictor:
|
class KVUPredictor:
|
||||||
def __init__(self, configs, class_names, dummy_idx, mode=0):
|
def __init__(self, configs, class_names, dummy_idx, mode=0):
|
||||||
@ -20,9 +27,9 @@ class KVUPredictor:
|
|||||||
self.dummy_idx = dummy_idx
|
self.dummy_idx = dummy_idx
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
print('[INFO] Loading Key-Value Understanding model ...')
|
logger.info('Loading Key-Value Understanding model ...')
|
||||||
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
|
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
|
||||||
print("[INFO] Loaded model")
|
logger.info("Loaded model")
|
||||||
|
|
||||||
if mode == 3:
|
if mode == 3:
|
||||||
self.max_window_count = cfg.train.max_window_count
|
self.max_window_count = cfg.train.max_window_count
|
||||||
@ -41,7 +48,7 @@ class KVUPredictor:
|
|||||||
cfg.stage = self.mode
|
cfg.stage = self.mode
|
||||||
backbone_type = cfg.model.backbone
|
backbone_type = cfg.model.backbone
|
||||||
|
|
||||||
print('[INFO] Checkpoint:', ckpt_path)
|
logger.info('Checkpoint:', ckpt_path)
|
||||||
net = get_model(cfg)
|
net = get_model(cfg)
|
||||||
load_model_weight(net, ckpt_path)
|
load_model_weight(net, ckpt_path)
|
||||||
net.to('cuda')
|
net.to('cuda')
|
||||||
|
@ -6,7 +6,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|||||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||||
from pytorch_lightning.plugins import DDPPlugin
|
from pytorch_lightning.plugins import DDPPlugin
|
||||||
from utils.ema_callbacks import EMA
|
from utils.ema_callbacks import EMA
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _update_config(cfg):
|
def _update_config(cfg):
|
||||||
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
||||||
@ -14,7 +20,7 @@ def _update_config(cfg):
|
|||||||
|
|
||||||
# set per-gpu batch size
|
# set per-gpu batch size
|
||||||
num_devices = torch.cuda.device_count()
|
num_devices = torch.cuda.device_count()
|
||||||
print('No. devices:', num_devices)
|
logger.info('No. devices:', num_devices)
|
||||||
for mode in ["train", "val"]:
|
for mode in ["train", "val"]:
|
||||||
new_batch_size = cfg[mode].batch_size // num_devices
|
new_batch_size = cfg[mode].batch_size // num_devices
|
||||||
cfg[mode].batch_size = new_batch_size
|
cfg[mode].batch_size = new_batch_size
|
||||||
@ -89,15 +95,15 @@ def create_exp_dir(save_dir=''):
|
|||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Experiment dir : {}'.format(save_dir))
|
logger.info('Experiment dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def create_dir(save_dir=''):
|
def create_dir(save_dir=''):
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Save dir : {}'.format(save_dir))
|
logger.info('Save dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def load_checkpoint(ckpt_path, model, key_include):
|
def load_checkpoint(ckpt_path, model, key_include):
|
||||||
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
|
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
|
||||||
@ -109,7 +115,7 @@ def load_checkpoint(ckpt_path, model, key_include):
|
|||||||
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
|
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
|
||||||
del state_dict[key]
|
del state_dict[key]
|
||||||
model.load_state_dict(state_dict, strict=True)
|
model.load_state_dict(state_dict, strict=True)
|
||||||
print(f"Load checkpoint at {ckpt_path}")
|
logger.info(f"Load checkpoint at {ckpt_path}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_model_weight(net, pretrained_model_file):
|
def load_model_weight(net, pretrained_model_file):
|
||||||
|
@ -10,25 +10,31 @@ from pdf2image import convert_from_path
|
|||||||
from dicttoxml import dicttoxml
|
from dicttoxml import dicttoxml
|
||||||
from word_preprocess import vat_standardizer, get_string, ap_standardizer
|
from word_preprocess import vat_standardizer, get_string, ap_standardizer
|
||||||
from kvu_dictionary import vat_dictionary, ap_dictionary
|
from kvu_dictionary import vat_dictionary, ap_dictionary
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_dir(save_dir=''):
|
def create_dir(save_dir=''):
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Save dir : {}'.format(save_dir))
|
logger.info('Save dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def pdf2image(pdf_dir, save_dir):
|
def pdf2image(pdf_dir, save_dir):
|
||||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||||
print('No. pdf files:', len(pdf_files))
|
logger.info('No. pdf files:', len(pdf_files))
|
||||||
|
|
||||||
for file in tqdm(pdf_files):
|
for file in tqdm(pdf_files):
|
||||||
pages = convert_from_path(file, 500)
|
pages = convert_from_path(file, 500)
|
||||||
for i, page in enumerate(pages):
|
for i, page in enumerate(pages):
|
||||||
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||||
print('Done!!!')
|
logger.info('Done!!!')
|
||||||
|
|
||||||
def xyxy2xywh(bbox):
|
def xyxy2xywh(bbox):
|
||||||
return [
|
return [
|
||||||
@ -239,7 +245,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
|||||||
try:
|
try:
|
||||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||||
except:
|
except:
|
||||||
print('Not valid pair:', wg_from, wg_to)
|
logger.info('Not valid pair:', wg_from, wg_to)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -257,7 +263,7 @@ def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['othe
|
|||||||
triplet_pairs = []
|
triplet_pairs = []
|
||||||
single_pairs = []
|
single_pairs = []
|
||||||
table = []
|
table = []
|
||||||
# print('key2values_relations', key2values_relations)
|
# logger.info('key2values_relations', key2values_relations)
|
||||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||||
if len(list_value_group_ids) == 0: continue
|
if len(list_value_group_ids) == 0: continue
|
||||||
elif len(list_value_group_ids) == 1:
|
elif len(list_value_group_ids) == 1:
|
||||||
@ -343,7 +349,7 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for key_name, value in pair.items():
|
for key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs):
|
if key_name in list(single_pairs):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -352,8 +358,8 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
|||||||
'lcs_score': score,
|
'lcs_score': score,
|
||||||
'token_id': value['id']
|
'token_id': value['id']
|
||||||
})
|
})
|
||||||
# print('='*10, file_path)
|
# logger.info('='*10, file_path)
|
||||||
# print(vat_info)
|
# logger.info(vat_info)
|
||||||
# Combine VAT information and table
|
# Combine VAT information and table
|
||||||
vat_outputs = {k: None for k in list(single_pairs)}
|
vat_outputs = {k: None for k in list(single_pairs)}
|
||||||
for key_name, list_potential_value in single_pairs.items():
|
for key_name, list_potential_value in single_pairs.items():
|
||||||
@ -387,7 +393,7 @@ def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['ot
|
|||||||
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
||||||
for cell in single_item:
|
for cell in single_item:
|
||||||
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||||
if header_name in list(item.keys()):
|
if header_name in list(item.keys()):
|
||||||
item[header_name].append({
|
item[header_name].append({
|
||||||
'content': cell['text'],
|
'content': cell['text'],
|
||||||
@ -436,7 +442,7 @@ def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['ot
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for key_name, value in pair.items():
|
for key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs):
|
if key_name in list(single_pairs):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
|
@ -5,12 +5,19 @@ import sys, os
|
|||||||
cur_dir = os.path.dirname(__file__)
|
cur_dir = os.path.dirname(__file__)
|
||||||
sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine"))
|
sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine"))
|
||||||
from src.ocr import OcrEngine
|
from src.ocr import OcrEngine
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_ocr_engine() -> OcrEngine:
|
def load_ocr_engine() -> OcrEngine:
|
||||||
print("[INFO] Loading engine...")
|
logger.info("[INFO] Loading engine...")
|
||||||
engine = OcrEngine()
|
engine = OcrEngine()
|
||||||
print("[INFO] Engine loaded")
|
logger.info("[INFO] Engine loaded")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
||||||
|
@ -22,6 +22,13 @@ from utils.kvu_dictionary import (
|
|||||||
ap_dictionary,
|
ap_dictionary,
|
||||||
manulife_dictionary
|
manulife_dictionary
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -29,20 +36,20 @@ def create_dir(save_dir=''):
|
|||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
# else:
|
# else:
|
||||||
# print("DIR already existed.")
|
# logger.info("DIR already existed.")
|
||||||
# print('Save dir : {}'.format(save_dir))
|
# logger.info('Save dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def convert_pdf2img(pdf_dir, save_dir):
|
def convert_pdf2img(pdf_dir, save_dir):
|
||||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||||
print('No. pdf files:', len(pdf_files))
|
logger.info('No. pdf files:', len(pdf_files))
|
||||||
print(pdf_files)
|
logger.info(pdf_files)
|
||||||
|
|
||||||
for file in tqdm(pdf_files):
|
for file in tqdm(pdf_files):
|
||||||
pdf2img(file, save_dir, n_pages=-1, return_fname=False)
|
pdf2img(file, save_dir, n_pages=-1, return_fname=False)
|
||||||
# pages = convert_from_path(file, 500)
|
# pages = convert_from_path(file, 500)
|
||||||
# for i, page in enumerate(pages):
|
# for i, page in enumerate(pages):
|
||||||
# page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
# page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||||
print('Done!!!')
|
logger.info('Done!!!')
|
||||||
|
|
||||||
def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False):
|
def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False):
|
||||||
file_names = []
|
file_names = []
|
||||||
@ -296,7 +303,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
|||||||
try:
|
try:
|
||||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||||
except:
|
except:
|
||||||
print('Not valid pair:', wg_from, wg_to)
|
logger.info('Not valid pair:', wg_from, wg_to)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def get_single_entity(word_groups: dict, lrelations: list) -> list:
|
def get_single_entity(word_groups: dict, lrelations: list) -> list:
|
||||||
@ -324,7 +331,7 @@ def export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labe
|
|||||||
triplet_pairs = []
|
triplet_pairs = []
|
||||||
single_pairs = []
|
single_pairs = []
|
||||||
table = []
|
table = []
|
||||||
# print('key2values_relations', key2values_relations)
|
# logger.info('key2values_relations', key2values_relations)
|
||||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||||
if len(list_value_group_ids) == 0: continue
|
if len(list_value_group_ids) == 0: continue
|
||||||
elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())):
|
elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())):
|
||||||
@ -443,7 +450,7 @@ def export_kvu_for_all(file_path, lwords, lbboxes, class_words, lrelations, labe
|
|||||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||||
if header_list:
|
if header_list:
|
||||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||||
print("Header_list:", header_list.keys())
|
logger.info("Header_list:", header_list.keys())
|
||||||
|
|
||||||
for row in raw_outputs["table"]:
|
for row in raw_outputs["table"]:
|
||||||
item = {header: None for header in list(header_list.keys())}
|
item = {header: None for header in list(header_list.keys())}
|
||||||
@ -517,7 +524,7 @@ def export_kvu_for_manulife(
|
|||||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||||
if header_list:
|
if header_list:
|
||||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||||
# print("Header_list:", header_list.keys())
|
# logger.info("Header_list:", header_list.keys())
|
||||||
|
|
||||||
for row in raw_outputs["table"]:
|
for row in raw_outputs["table"]:
|
||||||
item = {header: None for header in list(header_list.keys())}
|
item = {header: None for header in list(header_list.keys())}
|
||||||
@ -539,7 +546,7 @@ def get_vat_table_information(outputs):
|
|||||||
for single_item in outputs['table']:
|
for single_item in outputs['table']:
|
||||||
headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item]
|
headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item]
|
||||||
item = {k: [] for k in headers}
|
item = {k: [] for k in headers}
|
||||||
print(item)
|
logger.info(item)
|
||||||
for cell in single_item:
|
for cell in single_item:
|
||||||
# header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
|
# header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
|
||||||
# if header_name in list(item.keys()):
|
# if header_name in list(item.keys()):
|
||||||
@ -565,7 +572,7 @@ def get_vat_table_information(outputs):
|
|||||||
# if item["Mặt hàng"] == None:
|
# if item["Mặt hàng"] == None:
|
||||||
# continue
|
# continue
|
||||||
table.append(item)
|
table.append(item)
|
||||||
print(table)
|
logger.info(table)
|
||||||
return table
|
return table
|
||||||
|
|
||||||
def get_vat_information(outputs):
|
def get_vat_information(outputs):
|
||||||
@ -574,7 +581,7 @@ def get_vat_information(outputs):
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for raw_key_name, value in pair.items():
|
for raw_key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -588,7 +595,7 @@ def get_vat_information(outputs):
|
|||||||
for key, value_list in triplet.items():
|
for key, value_list in triplet.items():
|
||||||
if len(value_list) == 1:
|
if len(value_list) == 1:
|
||||||
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -600,7 +607,7 @@ def get_vat_information(outputs):
|
|||||||
|
|
||||||
for pair in value_list:
|
for pair in value_list:
|
||||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -613,7 +620,7 @@ def get_vat_information(outputs):
|
|||||||
for table_row in outputs['table']:
|
for table_row in outputs['table']:
|
||||||
for pair in table_row:
|
for pair in table_row:
|
||||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -674,7 +681,7 @@ def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, label
|
|||||||
vat_outputs['table'] = table
|
vat_outputs['table'] = table
|
||||||
|
|
||||||
write_to_json(file_path, vat_outputs)
|
write_to_json(file_path, vat_outputs)
|
||||||
print(vat_outputs)
|
logger.info(vat_outputs)
|
||||||
return vat_outputs
|
return vat_outputs
|
||||||
|
|
||||||
|
|
||||||
@ -686,7 +693,7 @@ def get_ap_table_information(outputs):
|
|||||||
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
||||||
for cell in single_item:
|
for cell in single_item:
|
||||||
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||||
if header_name in list(item.keys()):
|
if header_name in list(item.keys()):
|
||||||
item[header_name].append({
|
item[header_name].append({
|
||||||
'content': cell['text'],
|
'content': cell['text'],
|
||||||
@ -740,7 +747,7 @@ def get_ap_information(outputs):
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for raw_key_name, value in pair.items():
|
for raw_key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs):
|
if key_name in list(single_pairs):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -763,7 +770,7 @@ def get_ap_information(outputs):
|
|||||||
|
|
||||||
if all(v is not None for k, v in pair.items()) and is_product_info == True:
|
if all(v is not None for k, v in pair.items()) and is_product_info == True:
|
||||||
key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False)
|
key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False)
|
||||||
# print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")
|
# logger.info(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs):
|
if key_name in list(single_pairs):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -778,7 +785,7 @@ def get_ap_information(outputs):
|
|||||||
for key_name, list_potential_value in single_pairs.items():
|
for key_name, list_potential_value in single_pairs.items():
|
||||||
if len(list_potential_value) == 0: continue
|
if len(list_potential_value) == 0: continue
|
||||||
if key_name == "imei_number":
|
if key_name == "imei_number":
|
||||||
# print('list_potential_value', list_potential_value)
|
# logger.info('list_potential_value', list_potential_value)
|
||||||
# ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5]
|
# ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5]
|
||||||
ap_outputs[key_name] = []
|
ap_outputs[key_name] = []
|
||||||
for v in list_potential_value:
|
for v in list_potential_value:
|
||||||
|
@ -1,3 +1,12 @@
|
|||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Word():
|
class Word():
|
||||||
def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""):
|
def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""):
|
||||||
self.type = "word"
|
self.type = "word"
|
||||||
@ -43,7 +52,7 @@ class Word_group():
|
|||||||
if word.text != "✪":
|
if word.text != "✪":
|
||||||
for w in self.list_words:
|
for w in self.list_words:
|
||||||
if word.word_id == w.word_id:
|
if word.word_id == w.word_id:
|
||||||
print("Word id collision")
|
logger.info("Word id collision")
|
||||||
return False
|
return False
|
||||||
word.word_group_id = self.word_group_id #
|
word.word_group_id = self.word_group_id #
|
||||||
word.line_id = self.line_id
|
word.line_id = self.line_id
|
||||||
@ -92,7 +101,7 @@ class Line():
|
|||||||
if word_group.list_words is not None:
|
if word_group.list_words is not None:
|
||||||
for wg in self.list_word_groups:
|
for wg in self.list_word_groups:
|
||||||
if word_group.word_group_id == wg.word_group_id:
|
if word_group.word_group_id == wg.word_group_id:
|
||||||
print("Word_group id collision")
|
logger.info("Word_group id collision")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.list_word_groups.append(word_group)
|
self.list_word_groups.append(word_group)
|
||||||
@ -176,7 +185,6 @@ def words_to_lines(words, check_special_lines=True): #words is list of Word inst
|
|||||||
new_line.merge_word(word)
|
new_line.merge_word(word)
|
||||||
lines.append(new_line)
|
lines.append(new_line)
|
||||||
|
|
||||||
# print(len(lines))
|
|
||||||
#sort line from top to bottom according top coordinate
|
#sort line from top to bottom according top coordinate
|
||||||
lines.sort(key = lambda x: x.boundingbox[1])
|
lines.sort(key = lambda x: x.boundingbox[1])
|
||||||
|
|
||||||
@ -189,7 +197,6 @@ def words_to_lines(words, check_special_lines=True): #words is list of Word inst
|
|||||||
continue
|
continue
|
||||||
#left, top ,right, bottom
|
#left, top ,right, bottom
|
||||||
line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
|
line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
|
||||||
# print("line_width",line_width)
|
|
||||||
lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right
|
lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right
|
||||||
|
|
||||||
#update text for lines after sorting
|
#update text for lines after sorting
|
||||||
|
@ -4,6 +4,15 @@ import string
|
|||||||
import copy
|
import copy
|
||||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML
|
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML
|
||||||
from word2line import Word, words_to_lines
|
from word2line import Word, words_to_lines
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
nltk.download('words')
|
nltk.download('words')
|
||||||
words = set(nltk.corpus.words.words())
|
words = set(nltk.corpus.words.words())
|
||||||
|
|
||||||
@ -32,7 +41,6 @@ def remove_punctuation(text):
|
|||||||
|
|
||||||
def remove_accents(input_str, s0, s1):
|
def remove_accents(input_str, s0, s1):
|
||||||
s = ''
|
s = ''
|
||||||
# print input_str.encode('utf-8')
|
|
||||||
for c in input_str:
|
for c in input_str:
|
||||||
if c in s1:
|
if c in s1:
|
||||||
s += s0[s1.index(c)]
|
s += s0[s1.index(c)]
|
||||||
@ -44,7 +52,6 @@ def remove_spaces(text):
|
|||||||
return text.replace(' ', '')
|
return text.replace(' ', '')
|
||||||
|
|
||||||
def preprocessing(text: str):
|
def preprocessing(text: str):
|
||||||
# text = remove_english_words(text) if table else text
|
|
||||||
text = remove_punctuation(text)
|
text = remove_punctuation(text)
|
||||||
text = remove_accents(text, s0, s1)
|
text = remove_accents(text, s0, s1)
|
||||||
text = remove_spaces(text)
|
text = remove_spaces(text)
|
||||||
@ -184,7 +191,7 @@ def post_process_for_item(item: dict) -> dict:
|
|||||||
elif mis_key[0] == check_keys[2]:
|
elif mis_key[0] == check_keys[2]:
|
||||||
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
|
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Cannot post process this item with error:", e)
|
logger.error("Cannot post process this item with error:", e)
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
@ -280,9 +287,9 @@ def get_string_with_word2line(lwords: list, lbboxes: list):
|
|||||||
string_after_word2line = ' '.join(list_sorted_words)
|
string_after_word2line = ' '.join(list_sorted_words)
|
||||||
|
|
||||||
if string_from_model != string_after_word2line:
|
if string_from_model != string_after_word2line:
|
||||||
print("[Warning] Word group from model is different with word2line module")
|
logger.warning("[Warning] Word group from model is different with word2line module")
|
||||||
print("Model: ", ' '.join(unique_list))
|
logger.warning("Model: ", ' '.join(unique_list))
|
||||||
print("Word2line: ", ' '.join(list_sorted_words))
|
logger.warning("Word2line: ", ' '.join(list_sorted_words))
|
||||||
|
|
||||||
return string_after_word2line
|
return string_after_word2line
|
||||||
|
|
||||||
|
@ -49,10 +49,8 @@ def predict(image_url):
|
|||||||
"confidence": output[key]['conf']
|
"confidence": output[key]['conf']
|
||||||
}
|
}
|
||||||
output_dict['fields'].append(field)
|
output_dict['fields'].append(field)
|
||||||
print(output_dict)
|
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg"
|
image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg"
|
||||||
output = predict(image_url)
|
output = predict(image_url)
|
||||||
print(output)
|
|
@ -60,18 +60,12 @@ def predict_fi(page_numb, image_url):
|
|||||||
output_kie = {
|
output_kie = {
|
||||||
field_name: field_item['value'] for field_name, field_item in output.items()
|
field_name: field_item['value'] for field_name, field_item in output.items()
|
||||||
}
|
}
|
||||||
# print("Hoangggggggggggggggggggggggggggggggggggggggggggggg")
|
|
||||||
# print(output_kie)
|
|
||||||
|
|
||||||
|
|
||||||
#Phan cua Tuan
|
|
||||||
kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor)
|
kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor)
|
||||||
# print("TuanNnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn")
|
|
||||||
# print(kvu_result)
|
|
||||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||||
return kvu_result, output_kie
|
return kvu_result, output_kie
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg"
|
image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg"
|
||||||
output = predict_fi(0, image_url)
|
output = predict_fi(0, image_url)
|
||||||
print(output)
|
|
@ -69,7 +69,6 @@ def predict(page_numb, image_url):
|
|||||||
"page": page_numb
|
"page": page_numb
|
||||||
}
|
}
|
||||||
output_dict['fields'].append(field)
|
output_dict['fields'].append(field)
|
||||||
print(output_dict)
|
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||||
@ -142,5 +141,4 @@ def predict(page_numb, image_url):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||||
output = predict(0, image_url)
|
output = predict(0, image_url)
|
||||||
print(output)
|
|
@ -1,106 +0,0 @@
|
|||||||
1113 773 1220 825 BEST
|
|
||||||
1243 759 1378 808 DENKI
|
|
||||||
1410 752 1487 799 (S)
|
|
||||||
1430 707 1515 748 TAX
|
|
||||||
1511 745 1598 790 PTE
|
|
||||||
1542 700 1725 740 TNVOICE
|
|
||||||
1618 742 1706 783 LTD
|
|
||||||
1783 725 1920 773 FUNAN
|
|
||||||
1943 723 2054 767 MALL
|
|
||||||
1434 797 1576 843 WORTH
|
|
||||||
1599 785 1760 831 BRIDGE
|
|
||||||
1784 778 1846 822 RD
|
|
||||||
1277 846 1632 897 #02-16/#03-1
|
|
||||||
1655 832 1795 877 FUNAN
|
|
||||||
1817 822 1931 869 MALL
|
|
||||||
1272 897 1518 956 S(179105)
|
|
||||||
1548 890 1655 943 TEL:
|
|
||||||
1686 877 1911 928 69046183
|
|
||||||
1247 1011 1334 1068 GST
|
|
||||||
1358 1006 1447 1059 REG
|
|
||||||
1360 1063 1449 1115 RCB
|
|
||||||
1473 1003 1575 1055 NO.:
|
|
||||||
1474 1059 1555 1110 NO.
|
|
||||||
1595 1042 1868 1096 198202199E
|
|
||||||
1607 985 1944 1040 M2-0053813-7
|
|
||||||
1056 1134 1254 1194 Opening
|
|
||||||
1276 1127 1391 1181 Hrs:
|
|
||||||
1425 1112 1647 1170 10:00:00
|
|
||||||
1672 1102 1735 1161 AN
|
|
||||||
1755 1101 1819 1157 to
|
|
||||||
1846 1090 2067 1147 10:00:00
|
|
||||||
2090 1080 2156 1141 PH
|
|
||||||
1061 1308 1228 1366 Staff:
|
|
||||||
1258 1300 1378 1357 3296
|
|
||||||
1710 1283 1880 1337 Trans:
|
|
||||||
1936 1266 2192 1322 262152554
|
|
||||||
1060 1372 1201 1429 Date:
|
|
||||||
1260 1358 1494 1419 22-03-23
|
|
||||||
1540 1344 1664 1409 9:05
|
|
||||||
1712 1339 1856 1407 Slip:
|
|
||||||
1917 1328 2196 1387 2000130286
|
|
||||||
1124 1487 1439 1545 SALESPERSON
|
|
||||||
1465 1477 1601 1537 CODE.
|
|
||||||
1633 1471 1752 1530 6043
|
|
||||||
1777 1462 2004 1519 HUHAHHAD
|
|
||||||
2032 1451 2177 1509 RAZIH
|
|
||||||
1070 1558 1187 1617 Item
|
|
||||||
1211 1554 1276 1615 No
|
|
||||||
1439 1542 1585 1601 Price
|
|
||||||
1750 1530 1841 1597 Qty
|
|
||||||
1951 1517 2120 1579 Amount
|
|
||||||
1076 1683 1276 1741 ANDROID
|
|
||||||
1304 1673 1477 1733 TABLET
|
|
||||||
1080 1746 1280 1804 2105976
|
|
||||||
1509 1729 1705 1784 SAMSUNG
|
|
||||||
1734 1719 1931 1776 SH-P613
|
|
||||||
1964 1709 2101 1768 128GB
|
|
||||||
1082 1809 1285 1869 SM-P613
|
|
||||||
1316 1802 1454 1860 12838
|
|
||||||
1429 1859 1600 1919 518.00
|
|
||||||
1481 1794 1596 1855 WIFI
|
|
||||||
1622 1790 1656 1850 G
|
|
||||||
1797 1845 1824 1904 1
|
|
||||||
1993 1832 2165 1892 518.00
|
|
||||||
1088 1935 1347 1995 PROMOTION
|
|
||||||
1091 2000 1294 2062 2105664
|
|
||||||
1520 1983 1717 2039 SAMSUNG
|
|
||||||
1743 1963 2106 2030 F-Sam-Redeen
|
|
||||||
1439 2111 1557 2173 0.00
|
|
||||||
1806 2095 1832 2156 1
|
|
||||||
2053 2081 2174 2144 0.00
|
|
||||||
1106 2248 1250 2312 Total
|
|
||||||
1974 2206 2146 2266 518.00
|
|
||||||
1107 2312 1204 2377 UOB
|
|
||||||
1448 2291 1567 2355 CARD
|
|
||||||
1978 2268 2147 2327 518.00
|
|
||||||
1253 2424 1375 2497 GST%
|
|
||||||
1456 2411 1655 2475 Net.Amt
|
|
||||||
1818 2393 1912 2460 GST
|
|
||||||
2023 2387 2192 2445 Amount
|
|
||||||
1106 2494 1231 2560 GST8
|
|
||||||
1486 2472 1661 2537 479.63
|
|
||||||
1770 2458 1916 2523 38.37
|
|
||||||
2027 2448 2203 2511 518.00
|
|
||||||
1553 2601 1699 2666 THANK
|
|
||||||
1721 2592 1821 2661 YOU
|
|
||||||
1436 2678 1616 2749 please
|
|
||||||
1644 2682 1764 2732 come
|
|
||||||
1790 2660 1942 2729 again
|
|
||||||
1191 2862 1391 2931 Those
|
|
||||||
1426 2870 2018 2945 facebook.com
|
|
||||||
1565 2809 1690 2884 join
|
|
||||||
1709 2816 1777 2870 us
|
|
||||||
1799 2811 1868 2865 on
|
|
||||||
1838 2946 2024 3003 com .89
|
|
||||||
1533 3006 2070 3088 ar.com/askbe
|
|
||||||
1300 3326 1659 3446 That's
|
|
||||||
1696 3308 1905 3424 not
|
|
||||||
1937 3289 2131 3408 all!
|
|
||||||
1450 3511 1633 3573 SCAN
|
|
||||||
1392 3589 1489 3645 QR
|
|
||||||
1509 3577 1698 3635 CODE
|
|
||||||
1321 3656 1370 3714 &
|
|
||||||
1517 3638 1768 3699 updates
|
|
||||||
1643 3882 1769 3932 Scan
|
|
||||||
1789 3868 1859 3926 Me
|
|
Binary file not shown.
Before Width: | Height: | Size: 1.1 MiB |
@ -100,7 +100,6 @@ def word_to_line(list_words):
|
|||||||
"""
|
"""
|
||||||
texts, boundingboxes = [], []
|
texts, boundingboxes = [], []
|
||||||
for line in list_words:
|
for line in list_words:
|
||||||
print(line.text)
|
|
||||||
if line.text == "":
|
if line.text == "":
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
@ -6,6 +6,13 @@ det_ckpt = "/models/sdsvtd/hub/wild_receipt_finetune_weights_c_lite.pth"
|
|||||||
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
||||||
|
|
||||||
engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
|
engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def ocr_predict(img):
|
def ocr_predict(img):
|
||||||
@ -24,7 +31,7 @@ def ocr_predict(img):
|
|||||||
return list_lines
|
return list_lines
|
||||||
# return lbboxes, lwords
|
# return lbboxes, lwords
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
logger.info(e)
|
||||||
list_lines = []
|
list_lines = []
|
||||||
return list_lines
|
return list_lines
|
||||||
|
|
||||||
|
@@ -9,16 +9,23 @@ sys.path.append(cur_dir)
 from modules.sdsvkvu import load_engine, process_img
 from modules.ocr_engine import OcrEngine
 from configs.manulife import device, ocr_cfg, kvu_cfg
+import logging
+import logging.config
+from utils.logging.logging import LOGGER_CONFIG
+# Load the logging configuration
+logging.config.dictConfig(LOGGER_CONFIG)
+# Get the logger
+logger = logging.getLogger(__name__)
 
 def load_ocr_engine(opt) -> OcrEngine:
-    print("[INFO] Loading engine...")
+    logger.info("[INFO] Loading engine...")
     engine = OcrEngine(**opt)
-    print("[INFO] Engine loaded")
+    logger.info("[INFO] Engine loaded")
     return engine
 
 
-print("OCR engine configfs: \n", ocr_cfg)
-print("KVU configfs: \n", kvu_cfg)
+logger.info("OCR engine configfs: \n", ocr_cfg)
+logger.info("KVU configfs: \n", kvu_cfg)
 
 ocr_engine = load_ocr_engine(ocr_cfg)
 kvu_cfg['ocr_engine'] = ocr_engine
@@ -86,7 +93,7 @@ def predict(page_numb, image_url):
             "page": page_numb
         }
         output_dict['fields'].append(field)
-    print(output_dict)
+    logger.info(output_dict)
     return output_dict
 
 
@@ -95,4 +102,4 @@ def predict(page_numb, image_url):
 if __name__ == "__main__":
     image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
     output = predict(0, image_url)
-    print(output)
+    logger.info(output)
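Many of the converted calls keep their print-style argument lists, for example logger.info("OCR engine configfs: \n", ocr_cfg) above and logger.info('No. pdf files:', len(pdf_files)) later in this commit. Unlike print, the logging module treats extra positional arguments as lazy %-format arguments, so when the message contains no placeholder the extra value is not appended to the output (non-mapping arguments even trigger a formatting-error report from the handler). A small illustration of the difference, using a generic console logger:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

pdf_files = ["a.pdf", "b.pdf"]

# print-style call: len(pdf_files) is treated as a %-format argument.
# With no %s in the message it is never rendered; the handler reports a
# formatting error instead of logging "No. pdf files: 2".
logger.info('No. pdf files:', len(pdf_files))

# logging-style alternatives that do render the value:
logger.info('No. pdf files: %s', len(pdf_files))   # lazy %-formatting
logger.info(f'No. pdf files: {len(pdf_files)}')    # eager f-string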
@@ -14,9 +14,17 @@ nltk.data.path.append(os.path.join((os.getcwd() + '/nltk_data')))
 
 from modules.sdsvkvu import load_engine, process_img
 from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
+import logging
+import logging.config
+from utils.logging.logging import LOGGER_CONFIG
 
-print("OCR engine configfs: \n", ocr_cfg)
-print("KVU configfs: \n", kvu_cfg)
+# Load the logging configuration
+logging.config.dictConfig(LOGGER_CONFIG)
+# Get the logger
+logger = logging.getLogger(__name__)
+
+logger.info("OCR engine configfs: \n", ocr_cfg)
+logger.info("KVU configfs: \n", kvu_cfg)
 
 # ocr_engine = load_ocr_engine(ocr_cfg)
 # kvu_cfg['ocr_engine'] = ocr_engine
@@ -40,7 +48,7 @@ def sbt_predict(image_url, engine, metadata={}) -> None:
         query_params = urllib.parse.parse_qs(parsed_url.query)
         file_name = query_params['file_name'][0]
     except Exception as e:
-        print(f"[ERROR]: Error extracting file name from url: {image_url}")
+        logger.info(f"[ERROR]: Error extracting file name from url: {image_url}")
         file_name = f"{uuid.uuid4()}.jpg"
     os.makedirs(save_dir, exist_ok=True)
     # image_path = os.path.join(save_dir, f"{image_url}.jpg")
@@ -103,4 +111,4 @@ def predict(page_numb, image_url, metadata={}):
 if __name__ == "__main__":
     image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
     output = predict(0, image_url)
-    print(output)
+    logger.info(output)
@ -1,81 +0,0 @@
|
|||||||
from celery import Celery
|
|
||||||
import base64
|
|
||||||
import environ
|
|
||||||
env = environ.Env(
|
|
||||||
DEBUG=(bool, False)
|
|
||||||
)
|
|
||||||
|
|
||||||
class CeleryConnector:
|
|
||||||
task_routes = {
|
|
||||||
"process_id_result": {"queue": "id_card_rs"},
|
|
||||||
"process_driver_license_result": {"queue": "driver_license_rs"},
|
|
||||||
"process_invoice_result": {"queue": "invoice_rs"},
|
|
||||||
"process_ocr_with_box_result": {"queue": "ocr_with_box_rs"},
|
|
||||||
"process_template_matching_result": {"queue": "template_matching_rs"},
|
|
||||||
# mock task
|
|
||||||
"process_id": {"queue": "id_card"},
|
|
||||||
"process_driver_license": {"queue": "driver_license"},
|
|
||||||
"process_invoice": {"queue": "invoice"},
|
|
||||||
"process_ocr_with_box": {"queue": "ocr_with_box"},
|
|
||||||
"process_template_matching": {"queue": "template_matching"},
|
|
||||||
}
|
|
||||||
app = Celery(
|
|
||||||
"postman",
|
|
||||||
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
|
|
||||||
broker_transport_options={'confirm_publish': False},
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_id_result(self, args):
|
|
||||||
return self.send_task("process_id_result", args)
|
|
||||||
|
|
||||||
def process_driver_license_result(self, args):
|
|
||||||
return self.send_task("process_driver_license_result", args)
|
|
||||||
|
|
||||||
def process_invoice_result(self, args):
|
|
||||||
return self.send_task("process_invoice_result", args)
|
|
||||||
|
|
||||||
def process_ocr_with_box_result(self, args):
|
|
||||||
return self.send_task("process_ocr_with_box_result", args)
|
|
||||||
|
|
||||||
def process_template_matching_result(self, args):
|
|
||||||
return self.send_task("process_template_matching_result", args)
|
|
||||||
|
|
||||||
def process_id(self, args):
|
|
||||||
return self.send_task("process_id", args)
|
|
||||||
|
|
||||||
def process_driver_license(self, args):
|
|
||||||
return self.send_task("process_driver_license", args)
|
|
||||||
|
|
||||||
def process_invoice(self, args):
|
|
||||||
return self.send_task("process_invoice", args)
|
|
||||||
|
|
||||||
def process_ocr_with_box(self, args):
|
|
||||||
return self.send_task("process_ocr_with_box", args)
|
|
||||||
|
|
||||||
def process_template_matching(self, args):
|
|
||||||
return self.send_task("process_template_matching", args)
|
|
||||||
|
|
||||||
def send_task(self, name=None, args=None):
|
|
||||||
if name not in self.task_routes or "queue" not in self.task_routes[name]:
|
|
||||||
return self.app.send_task(name, args)
|
|
||||||
|
|
||||||
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
rq_id = 345
|
|
||||||
file_names = "abc.jpg"
|
|
||||||
list_data = []
|
|
||||||
|
|
||||||
with open("/home/sds/thucpd/aicr-2022/abc.jpg", "rb") as fs:
|
|
||||||
encoded_string = base64.b64encode(fs.read()).decode("utf-8")
|
|
||||||
list_data.append(encoded_string)
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
a = c_connector.process_id(args=(rq_id, list_data, file_names))
|
|
||||||
|
|
||||||
print(a)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@@ -1,5 +1,7 @@
 from celery import Celery
 import environ
+from utils.logging.local_storage import get_current_trace_id
+
 env = environ.Env(
     DEBUG=(bool, False)
 )
@@ -53,5 +55,6 @@ class CeleryConnector:
     def send_task(self, name=None, args=None):
         if name not in self.task_routes or "queue" not in self.task_routes[name]:
             return self.app.send_task(name, args)
+        trace_id = get_current_trace_id()
+        args += (trace_id,)  # add trace_id to args then remove before start
         return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
@ -1,220 +0,0 @@
|
|||||||
from celery_worker.worker import app
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
@app.task(name="process_id")
|
|
||||||
def process_id(rq_id, sub_id, folder_name, list_url, user_id):
|
|
||||||
from common.serve_model import predict
|
|
||||||
from celery_worker.client_connector import CeleryConnector
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
try:
|
|
||||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="id_card")
|
|
||||||
print(result)
|
|
||||||
result = {
|
|
||||||
"status": 200,
|
|
||||||
"content": result,
|
|
||||||
"message": "Success",
|
|
||||||
}
|
|
||||||
c_connector.process_id_result((rq_id, result))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
# if image_croped is not None:
|
|
||||||
# if result["data"] == []:
|
|
||||||
# result = {
|
|
||||||
# "status": 404,
|
|
||||||
# "content": {},
|
|
||||||
# }
|
|
||||||
# c_connector.process_id_result((rq_id, result, None))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
# else:
|
|
||||||
# result = {
|
|
||||||
# "status": 200,
|
|
||||||
# "content": result,
|
|
||||||
# "message": "Success",
|
|
||||||
# }
|
|
||||||
# c_connector.process_id_result((rq_id, result))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
# elif image_croped is None:
|
|
||||||
# result = {
|
|
||||||
# "status": 404,
|
|
||||||
# "content": {},
|
|
||||||
# }
|
|
||||||
# c_connector.process_id_result((rq_id, result, None))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
result = {
|
|
||||||
"status": 404,
|
|
||||||
"content": {},
|
|
||||||
}
|
|
||||||
c_connector.process_id_result((rq_id, result, None))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
|
|
||||||
@app.task(name="process_driver_license")
|
|
||||||
def process_driver_license(rq_id, sub_id, folder_name, list_url, user_id):
|
|
||||||
from common.serve_model import predict
|
|
||||||
from celery_worker.client_connector import CeleryConnector
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
try:
|
|
||||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="driving_license")
|
|
||||||
result = {
|
|
||||||
"status": 200,
|
|
||||||
"content": result,
|
|
||||||
"message": "Success",
|
|
||||||
}
|
|
||||||
c_connector.process_driver_license_result((rq_id, result))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
# result, image_croped = predict(str(url), "driving_license")
|
|
||||||
# if image_croped is not None:
|
|
||||||
# if result["data"] == []:
|
|
||||||
# result = {
|
|
||||||
# "status": 404,
|
|
||||||
# "content": {},
|
|
||||||
# }
|
|
||||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
# else:
|
|
||||||
# result = {
|
|
||||||
# "status": 200,
|
|
||||||
# "content": result,
|
|
||||||
# "message": "Success",
|
|
||||||
# }
|
|
||||||
# path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
|
||||||
# cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_croped)
|
|
||||||
# c_connector.process_driver_license_result((rq_id, result, path_image_croped))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
# elif image_croped is None:
|
|
||||||
# result = {
|
|
||||||
# "status": 404,
|
|
||||||
# "content": {},
|
|
||||||
# }
|
|
||||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
result = {
|
|
||||||
"status": 404,
|
|
||||||
"content": {},
|
|
||||||
}
|
|
||||||
c_connector.process_driver_license_result((rq_id, result, None))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
|
|
||||||
@app.task(name="process_template_matching")
|
|
||||||
def process_template_matching(rq_id, sub_id, folder_name, url, tmp_json, user_id):
|
|
||||||
from TemplateMatching.src.ocr_master import Extractor
|
|
||||||
from celery_worker.client_connector import CeleryConnector
|
|
||||||
import urllib
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
extractor = Extractor()
|
|
||||||
try:
|
|
||||||
req = urllib.request.urlopen(url)
|
|
||||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
|
||||||
img = cv2.imdecode(arr, -1)
|
|
||||||
imgs = [img]
|
|
||||||
image_aliged = extractor.image_alige(imgs, tmp_json)
|
|
||||||
if image_aliged is None:
|
|
||||||
result = {
|
|
||||||
"status": 401,
|
|
||||||
"content": "Image is not match with template",
|
|
||||||
}
|
|
||||||
c_connector.process_template_matching_result(
|
|
||||||
(rq_id, result, None)
|
|
||||||
)
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
else:
|
|
||||||
output = extractor.extract_information(
|
|
||||||
image_aliged, tmp_json
|
|
||||||
)
|
|
||||||
path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
|
||||||
cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_aliged)
|
|
||||||
if output == {}:
|
|
||||||
result = {"status": 404, "content": {}}
|
|
||||||
c_connector.process_template_matching_result((rq_id, result, None))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
else:
|
|
||||||
result = {
|
|
||||||
"document_type": "template_matching",
|
|
||||||
"fields": []
|
|
||||||
}
|
|
||||||
print(output)
|
|
||||||
for field in tmp_json["fields"]:
|
|
||||||
print(field["label"])
|
|
||||||
field_value = {
|
|
||||||
"label": field["label"],
|
|
||||||
"value": output[field["label"]],
|
|
||||||
"box": [float(num) for num in field["box"]],
|
|
||||||
"confidence": 0.98 #TODO confidence
|
|
||||||
}
|
|
||||||
result["fields"].append(field_value)
|
|
||||||
|
|
||||||
print(result)
|
|
||||||
result = {"status": 200, "content": result}
|
|
||||||
c_connector.process_template_matching_result(
|
|
||||||
(rq_id, result, path_image_croped)
|
|
||||||
)
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
result = {"status": 404, "content": {}}
|
|
||||||
c_connector.process_template_matching_result((rq_id, result, None))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
|
|
||||||
# @app.task(name="process_invoice")
|
|
||||||
# def process_invoice(rq_id, url):
|
|
||||||
# from celery_worker.client_connector import CeleryConnector
|
|
||||||
# from Kie_Hoanglv.prediction import predict
|
|
||||||
|
|
||||||
# c_connector = CeleryConnector()
|
|
||||||
# try:
|
|
||||||
# print(url)
|
|
||||||
# result = predict(str(url))
|
|
||||||
# hoadon = {"status": 200, "content": result, "message": "Success"}
|
|
||||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# print(e)
|
|
||||||
# hoadon = {"status": 404, "content": {}}
|
|
||||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
|
||||||
# return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
@app.task(name="process_invoice")
|
|
||||||
def process_invoice(rq_id, list_url):
|
|
||||||
from celery_worker.client_connector import CeleryConnector
|
|
||||||
from common.process_pdf import compile_output
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
try:
|
|
||||||
result = compile_output(list_url)
|
|
||||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
|
||||||
c_connector.process_invoice_result((rq_id, hoadon))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
hoadon = {"status": 404, "content": {}}
|
|
||||||
c_connector.process_invoice_result((rq_id, hoadon))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
|
|
||||||
|
|
||||||
@app.task(name="process_ocr_with_box")
|
|
||||||
def process_ocr_with_box(rq_id, list_url):
|
|
||||||
from celery_worker.client_connector import CeleryConnector
|
|
||||||
from common.process_pdf import compile_output_ocr_base
|
|
||||||
|
|
||||||
c_connector = CeleryConnector()
|
|
||||||
try:
|
|
||||||
result = compile_output_ocr_base(list_url)
|
|
||||||
result = {"status": 200, "content": result, "message": "Success"}
|
|
||||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
|
||||||
return {"rq_id": rq_id}
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
result = {"status": 404, "content": {}}
|
|
||||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
|
||||||
return {"rq_id": rq_id}
|
|
@@ -1,8 +1,16 @@
 from celery_worker.worker_fi import app
 from celery_worker.client_connector_fi import CeleryConnector
 from common.process_pdf import compile_output_sbt
+from .task_warpper import VerboseTask
+import logging
+import logging.config
+from utils.logging.logging import LOGGER_CONFIG
+# Load the logging configuration
+logging.config.dictConfig(LOGGER_CONFIG)
+# Get the logger
+logger = logging.getLogger(__name__)
 
-@app.task(name="process_fi_invoice")
+@app.task(base=VerboseTask,name="process_fi_invoice")
 def process_invoice(rq_id, list_url):
     from celery_worker.client_connector_fi import CeleryConnector
     from common.process_pdf import compile_output_fi
@@ -11,22 +19,22 @@ def process_invoice(rq_id, list_url):
     try:
         result = compile_output_fi(list_url)
         hoadon = {"status": 200, "content": result, "message": "Success"}
-        print(hoadon)
+        logger.info(hoadon)
         c_connector.process_fi_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
     except Exception as e:
-        print(e)
+        logger.info(e)
         hoadon = {"status": 404, "content": {}}
         c_connector.process_fi_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
 
 
-@app.task(name="process_sap_invoice")
+@app.task(base=VerboseTask,name="process_sap_invoice")
 def process_sap_invoice(rq_id, list_url):
     from celery_worker.client_connector_fi import CeleryConnector
     from common.process_pdf import compile_output
 
-    print(list_url)
+    logger.info(list_url)
     c_connector = CeleryConnector()
     try:
         result = compile_output(list_url)
@@ -34,12 +42,12 @@ def process_sap_invoice(rq_id, list_url):
         c_connector.process_sap_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
     except Exception as e:
-        print(e)
+        logger.info(e)
         hoadon = {"status": 404, "content": {}}
         c_connector.process_sap_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
 
-@app.task(name="process_manulife_invoice")
+@app.task(base=VerboseTask,name="process_manulife_invoice")
 def process_manulife_invoice(rq_id, list_url):
     from celery_worker.client_connector_fi import CeleryConnector
     from common.process_pdf import compile_output_manulife
@@ -48,16 +56,16 @@ def process_manulife_invoice(rq_id, list_url):
     try:
         result = compile_output_manulife(list_url)
         hoadon = {"status": 200, "content": result, "message": "Success"}
-        print(hoadon)
+        logger.info(hoadon)
         c_connector.process_manulife_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
     except Exception as e:
-        print(e)
+        logger.info(e)
         hoadon = {"status": 404, "content": {}}
         c_connector.process_manulife_invoice_result((rq_id, hoadon))
         return {"rq_id": rq_id}
 
-@app.task(name="process_sbt_invoice")
+@app.task(base=VerboseTask,name="process_sbt_invoice")
 def process_sbt_invoice(rq_id, list_url, metadata):
     # TODO: simply returning 200 and 404 doesn't make any sense
     c_connector = CeleryConnector()
@@ -65,12 +73,12 @@ def process_sbt_invoice(rq_id, list_url, metadata):
         result = compile_output_sbt(list_url, metadata)
         metadata['ai_inference_profile'] = result.pop("inference_profile")
         hoadon = {"status": 200, "content": result, "message": "Success"}
-        print(hoadon)
+        logger.info(hoadon)
         c_connector.process_sbt_invoice_result((rq_id, hoadon, metadata))
         return {"rq_id": rq_id}
     except Exception as e:
-        print(f"[ERROR]: Failed to extract invoice: {e}")
-        print(e)
+        logger.info(f"[ERROR]: Failed to extract invoice: {e}")
+        logger.info(e)
         hoadon = {"status": 404, "content": {}}
         c_connector.process_sbt_invoice_result((rq_id, hoadon, metadata))
         return {"rq_id": rq_id}
cope2n-ai-fi/celery_worker/task_warpper.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from celery import Task
+from celery.utils.log import get_task_logger
+from utils.logging.local_storage import get_current_trace_id, set_current_trace_id
+logger = get_task_logger(__name__)
+
+class VerboseTask(Task):
+    abstract = True
+
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
+        # Task failed. What do you want to do?
+        logger.error(f'FAILURE: Task: {self.name} - {task_id} | Task raised an exception: {exc}')
+
+    def on_success(self, retval, task_id, args, kwargs):
+        logger.info(f"SUCCESS: Task: {self.name} - {task_id} | retval: {retval} | args: {args} | kwargs: {kwargs}")
+
+    def before_start(self, task_id, args, kwargs):
+        trace_id = args[-1]
+        args.pop(-1)
+        set_current_trace_id(trace_id)
+        logger.info(f"BEFORE_START: Task: {self.name} - {task_id} | args: {args} | kwargs: {kwargs}")
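Read together with the send_task change to client_connector_fi.py earlier in this commit, the producer appends the current trace id as an extra last argument and VerboseTask.before_start pops it off and restores it in the worker before the task body runs. A minimal, Celery-free sketch of that round trip follows; the helper bodies are assumptions (utils/logging/local_storage.py is only partially shown), and the task name and argument values are illustrative.

from threading import local

_thread_locals = local()

def set_current_trace_id(trace_id):
    # Assumed implementation of the helper imported from utils.logging.local_storage.
    _thread_locals.trace_id = trace_id

def get_current_trace_id():
    return getattr(_thread_locals, "trace_id", None)

def send_task(name, args):
    # Producer side (CeleryConnector.send_task): piggyback the trace id as the last argument.
    args += (get_current_trace_id(),)
    return name, args

def before_start(args):
    # Worker side (VerboseTask.before_start): strip the extra argument and
    # restore it into the worker's thread-local storage.
    args = list(args)
    set_current_trace_id(args.pop(-1))
    return args

set_current_trace_id("trace-123")
_, wire_args = send_task("process_fi_invoice", ("rq_1", ["http://example/invoice.jpg"]))
task_args = before_start(wire_args)
assert task_args == ["rq_1", ["http://example/invoice.jpg"]]
assert get_current_trace_id() == "trace-123"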
@ -1,41 +0,0 @@
|
|||||||
from celery import Celery
|
|
||||||
from kombu import Queue, Exchange
|
|
||||||
import environ
|
|
||||||
env = environ.Env(
|
|
||||||
DEBUG=(bool, False)
|
|
||||||
)
|
|
||||||
|
|
||||||
app: Celery = Celery(
|
|
||||||
"postman",
|
|
||||||
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
|
|
||||||
# backend="rpc://",
|
|
||||||
include=[
|
|
||||||
"celery_worker.mock_process_tasks",
|
|
||||||
],
|
|
||||||
broker_transport_options={'confirm_publish': False},
|
|
||||||
)
|
|
||||||
task_exchange = Exchange("default", type="direct")
|
|
||||||
task_create_missing_queues = False
|
|
||||||
app.conf.update(
|
|
||||||
{
|
|
||||||
"result_expires": 3600,
|
|
||||||
"task_queues": [
|
|
||||||
Queue("id_card"),
|
|
||||||
Queue("driver_license"),
|
|
||||||
Queue("invoice"),
|
|
||||||
Queue("ocr_with_box"),
|
|
||||||
Queue("template_matching"),
|
|
||||||
],
|
|
||||||
"task_routes": {
|
|
||||||
"process_id": {"queue": "id_card"},
|
|
||||||
"process_driver_license": {"queue": "driver_license"},
|
|
||||||
"process_invoice": {"queue": "invoice"},
|
|
||||||
"process_ocr_with_box": {"queue": "ocr_with_box"},
|
|
||||||
"process_template_matching": {"queue": "template_matching"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
argv = ["celery_worker.worker", "--loglevel=INFO", "--pool=solo"] # Window opts
|
|
||||||
app.worker_main(argv)
|
|
@ -1,6 +1,7 @@
|
|||||||
from celery import Celery
|
from celery import Celery
|
||||||
from kombu import Queue, Exchange
|
from kombu import Queue, Exchange
|
||||||
import environ
|
import environ
|
||||||
|
|
||||||
env = environ.Env(
|
env = environ.Env(
|
||||||
DEBUG=(bool, False)
|
DEBUG=(bool, False)
|
||||||
)
|
)
|
||||||
@ -13,6 +14,7 @@ app: Celery = Celery(
|
|||||||
],
|
],
|
||||||
broker_transport_options={'confirm_publish': False},
|
broker_transport_options={'confirm_publish': False},
|
||||||
)
|
)
|
||||||
|
|
||||||
task_exchange = Exchange("default", type="direct")
|
task_exchange = Exchange("default", type="direct")
|
||||||
task_create_missing_queues = False
|
task_create_missing_queues = False
|
||||||
app.conf.update(
|
app.conf.update(
|
||||||
|
@ -98,4 +98,3 @@ if __name__ == "__main__":
|
|||||||
image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
|
image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
|
||||||
save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
|
save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
|
||||||
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
|
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
|
||||||
print('[INFO] Done')
|
|
||||||
|
@ -7,6 +7,13 @@ sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ??????
|
|||||||
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
||||||
from model import get_model
|
from model import get_model
|
||||||
from utils import load_model_weight
|
from utils import load_model_weight
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KVUPredictor:
|
class KVUPredictor:
|
||||||
@ -18,9 +25,9 @@ class KVUPredictor:
|
|||||||
self.dummy_idx = dummy_idx
|
self.dummy_idx = dummy_idx
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
print('[INFO] Loading Key-Value Understanding model ...')
|
logger.info('[INFO] Loading Key-Value Understanding model ...')
|
||||||
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
|
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
|
||||||
print("[INFO] Loaded model")
|
logger.info("[INFO] Loaded model")
|
||||||
|
|
||||||
if mode == 3:
|
if mode == 3:
|
||||||
self.max_window_count = cfg.train.max_window_count
|
self.max_window_count = cfg.train.max_window_count
|
||||||
@ -39,7 +46,7 @@ class KVUPredictor:
|
|||||||
cfg.stage = self.mode
|
cfg.stage = self.mode
|
||||||
backbone_type = cfg.model.backbone
|
backbone_type = cfg.model.backbone
|
||||||
|
|
||||||
print('[INFO] Checkpoint:', ckpt_path)
|
logger.info('[INFO] Checkpoint:', ckpt_path)
|
||||||
net = get_model(cfg)
|
net = get_model(cfg)
|
||||||
load_model_weight(net, ckpt_path)
|
load_model_weight(net, ckpt_path)
|
||||||
net.to('cuda')
|
net.to('cuda')
|
||||||
|
@ -6,7 +6,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|||||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||||
from pytorch_lightning.plugins import DDPPlugin
|
from pytorch_lightning.plugins import DDPPlugin
|
||||||
from utils.ema_callbacks import EMA
|
from utils.ema_callbacks import EMA
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _update_config(cfg):
|
def _update_config(cfg):
|
||||||
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
||||||
@ -14,7 +20,7 @@ def _update_config(cfg):
|
|||||||
|
|
||||||
# set per-gpu batch size
|
# set per-gpu batch size
|
||||||
num_devices = torch.cuda.device_count()
|
num_devices = torch.cuda.device_count()
|
||||||
print('No. devices:', num_devices)
|
logger.info('No. devices:', num_devices)
|
||||||
for mode in ["train", "val"]:
|
for mode in ["train", "val"]:
|
||||||
new_batch_size = cfg[mode].batch_size // num_devices
|
new_batch_size = cfg[mode].batch_size // num_devices
|
||||||
cfg[mode].batch_size = new_batch_size
|
cfg[mode].batch_size = new_batch_size
|
||||||
@ -89,15 +95,15 @@ def create_exp_dir(save_dir=''):
|
|||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Experiment dir : {}'.format(save_dir))
|
logger.info('Experiment dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def create_dir(save_dir=''):
|
def create_dir(save_dir=''):
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Save dir : {}'.format(save_dir))
|
logger.info('Save dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def load_checkpoint(ckpt_path, model, key_include):
|
def load_checkpoint(ckpt_path, model, key_include):
|
||||||
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
|
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
|
||||||
@ -109,7 +115,7 @@ def load_checkpoint(ckpt_path, model, key_include):
|
|||||||
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
|
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
|
||||||
del state_dict[key]
|
del state_dict[key]
|
||||||
model.load_state_dict(state_dict, strict=True)
|
model.load_state_dict(state_dict, strict=True)
|
||||||
print(f"Load checkpoint at {ckpt_path}")
|
logger.info(f"Load checkpoint at {ckpt_path}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_model_weight(net, pretrained_model_file):
|
def load_model_weight(net, pretrained_model_file):
|
||||||
|
@ -6,14 +6,21 @@ import sys
|
|||||||
# from src.ocr import OcrEngine
|
# from src.ocr import OcrEngine
|
||||||
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ??????
|
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ??????
|
||||||
import serve_model
|
import serve_model
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# def load_ocr_engine() -> OcrEngine:
|
# def load_ocr_engine() -> OcrEngine:
|
||||||
def load_ocr_engine() -> OcrEngine:
|
def load_ocr_engine() -> OcrEngine:
|
||||||
print("[INFO] Loading engine...")
|
logger.info("[INFO] Loading engine...")
|
||||||
# engine = OcrEngine()
|
# engine = OcrEngine()
|
||||||
engine = serve_model.engine
|
engine = serve_model.engine
|
||||||
print("[INFO] Engine loaded")
|
logger.info("[INFO] Engine loaded")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
||||||
|
@ -10,25 +10,31 @@ from pdf2image import convert_from_path
|
|||||||
from dicttoxml import dicttoxml
|
from dicttoxml import dicttoxml
|
||||||
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
|
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
|
||||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary
|
from utils.kvu_dictionary import vat_dictionary, ap_dictionary
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_dir(save_dir=''):
|
def create_dir(save_dir=''):
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
print("DIR already existed.")
|
logger.info("DIR already existed.")
|
||||||
print('Save dir : {}'.format(save_dir))
|
logger.info('Save dir : {}'.format(save_dir))
|
||||||
|
|
||||||
def pdf2image(pdf_dir, save_dir):
|
def pdf2image(pdf_dir, save_dir):
|
||||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||||
print('No. pdf files:', len(pdf_files))
|
logger.info('No. pdf files:', len(pdf_files))
|
||||||
|
|
||||||
for file in tqdm(pdf_files):
|
for file in tqdm(pdf_files):
|
||||||
pages = convert_from_path(file, 500)
|
pages = convert_from_path(file, 500)
|
||||||
for i, page in enumerate(pages):
|
for i, page in enumerate(pages):
|
||||||
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||||
print('Done!!!')
|
logger.info('Done!!!')
|
||||||
|
|
||||||
def xyxy2xywh(bbox):
|
def xyxy2xywh(bbox):
|
||||||
return [
|
return [
|
||||||
@ -246,7 +252,7 @@ def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
|||||||
try:
|
try:
|
||||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('Not valid pair:', wg_from, wg_to)
|
logger.info('Not valid pair:', wg_from, wg_to)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -264,7 +270,7 @@ def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['othe
|
|||||||
triplet_pairs = []
|
triplet_pairs = []
|
||||||
single_pairs = []
|
single_pairs = []
|
||||||
table = []
|
table = []
|
||||||
# print('key2values_relations', key2values_relations)
|
# logger.info('key2values_relations', key2values_relations)
|
||||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||||
if len(list_value_group_ids) == 0: continue
|
if len(list_value_group_ids) == 0: continue
|
||||||
elif len(list_value_group_ids) == 1:
|
elif len(list_value_group_ids) == 1:
|
||||||
@ -355,7 +361,7 @@ def get_vat_information(outputs):
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for raw_key_name, value in pair.items():
|
for raw_key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -369,7 +375,7 @@ def get_vat_information(outputs):
|
|||||||
for key, value_list in triplet.items():
|
for key, value_list in triplet.items():
|
||||||
if len(value_list) == 1:
|
if len(value_list) == 1:
|
||||||
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -381,7 +387,7 @@ def get_vat_information(outputs):
|
|||||||
|
|
||||||
for pair in value_list:
|
for pair in value_list:
|
||||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -394,7 +400,7 @@ def get_vat_information(outputs):
|
|||||||
for table_row in outputs['table']:
|
for table_row in outputs['table']:
|
||||||
for pair in table_row:
|
for pair in table_row:
|
||||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs.keys()):
|
if key_name in list(single_pairs.keys()):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
@ -461,7 +467,7 @@ def get_ap_table_information(outputs):
|
|||||||
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
|
||||||
for cell in single_item:
|
for cell in single_item:
|
||||||
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
|
||||||
if header_name in list(item.keys()):
|
if header_name in list(item.keys()):
|
||||||
item[header_name].append({
|
item[header_name].append({
|
||||||
'content': cell['text'],
|
'content': cell['text'],
|
||||||
@ -515,7 +521,7 @@ def get_ap_information(outputs):
|
|||||||
for pair in outputs['single']:
|
for pair in outputs['single']:
|
||||||
for key_name, value in pair.items():
|
for key_name, value in pair.items():
|
||||||
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
|
||||||
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
# logger.info(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||||
|
|
||||||
if key_name in list(single_pairs):
|
if key_name in list(single_pairs):
|
||||||
single_pairs[key_name].append({
|
single_pairs[key_name].append({
|
||||||
|
@ -5,6 +5,13 @@ import copy
|
|||||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
|
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
|
||||||
nltk.download('words')
|
nltk.download('words')
|
||||||
words = set(nltk.corpus.words.words())
|
words = set(nltk.corpus.words.words())
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
|
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
|
||||||
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
|
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
|
||||||
@ -31,7 +38,7 @@ def remove_punctuation(text):
|
|||||||
|
|
||||||
def remove_accents(input_str, s0, s1):
|
def remove_accents(input_str, s0, s1):
|
||||||
s = ''
|
s = ''
|
||||||
# print input_str.encode('utf-8')
|
# logger.info input_str.encode('utf-8')
|
||||||
for c in input_str:
|
for c in input_str:
|
||||||
if c in s1:
|
if c in s1:
|
||||||
s += s0[s1.index(c)]
|
s += s0[s1.index(c)]
|
||||||
@ -159,7 +166,7 @@ def post_process_for_item(item: dict) -> dict:
|
|||||||
elif mis_key[0] == check_keys[2]:
|
elif mis_key[0] == check_keys[2]:
|
||||||
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
|
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Cannot post process this item with error:", e)
|
logger.info("Cannot post process this item with error:", e)
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#")
|
ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#")
|
||||||
|
|
||||||
|
|
||||||
@ -124,7 +131,7 @@ def replace_xml_values(xml_str, replacement_dict):
|
|||||||
formatted_date = date_obj.strftime("%Y-%m-%d")
|
formatted_date = date_obj.strftime("%Y-%m-%d")
|
||||||
nlap_element.text = formatted_date
|
nlap_element.text = formatted_date
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(f"Invalid date format for {key}: {value}")
|
logger.info(f"Invalid date format for {key}: {value}")
|
||||||
nlap_element.text = value
|
nlap_element.text = value
|
||||||
else:
|
else:
|
||||||
element = root.find(f".//{key}")
|
element = root.find(f".//{key}")
|
||||||
@ -133,7 +140,7 @@ def replace_xml_values(xml_str, replacement_dict):
|
|||||||
ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#")
|
ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#")
|
||||||
return ET.tostring(root, encoding="unicode")
|
return ET.tostring(root, encoding="unicode")
|
||||||
except ET.ParseError as e:
|
except ET.ParseError as e:
|
||||||
print(f"Error parsing XML: {e}")
|
logger.info(f"Error parsing XML: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,6 +5,13 @@ det_ckpt = "yolox-s-general-text-pretrain-20221226"
|
|||||||
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
||||||
|
|
||||||
engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt)
|
engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt)
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def ocr_predict(image):
|
def ocr_predict(image):
|
||||||
@ -22,7 +29,7 @@ def ocr_predict(image):
|
|||||||
list_lines, _ = words_to_lines(lWords)
|
list_lines, _ = words_to_lines(lWords)
|
||||||
return list_lines
|
return list_lines
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
logger.info(e)
|
||||||
list_lines = []
|
list_lines = []
|
||||||
return list_lines
|
return list_lines
|
||||||
|
|
||||||
|
@ -3,7 +3,13 @@ from datetime import datetime
|
|||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
from common.utils.utils import read_json
|
from common.utils.utils import read_json
|
||||||
from underthesea import word_tokenize
|
from underthesea import word_tokenize
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DatetimeCorrector:
|
class DatetimeCorrector:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -92,8 +98,6 @@ class DatetimeCorrector:
|
|||||||
for k, d in data.items():
|
for k, d in data.items():
|
||||||
if k in lexcludes:
|
if k in lexcludes:
|
||||||
continue
|
continue
|
||||||
if k == "inv_SDV_215":
|
|
||||||
print("debugging")
|
|
||||||
pred = DatetimeCorrector.correct(d["pred"])
|
pred = DatetimeCorrector.correct(d["pred"])
|
||||||
label = DatetimeCorrector.correct(d["label"])
|
label = DatetimeCorrector.correct(d["label"])
|
||||||
ddata[k] = {}
|
ddata[k] = {}
|
||||||
@ -103,11 +107,8 @@ class DatetimeCorrector:
|
|||||||
ddata[k]["Post-processed"] = pred
|
ddata[k]["Post-processed"] = pred
|
||||||
y_pred.append(pred == label)
|
y_pred.append(pred == label)
|
||||||
y_true.append(1)
|
y_true.append(1)
|
||||||
if k == "invoice_1219_000":
|
|
||||||
print("\n", k, '-' * 50)
|
logger.info(classification_report(y_true, y_pred))
|
||||||
print(pred, "------", d["pred"])
|
|
||||||
print(label, "------", d["label"])
|
|
||||||
print(classification_report(y_true, y_pred))
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
df = pd.DataFrame.from_dict(ddata, orient="index")
|
df = pd.DataFrame.from_dict(ddata, orient="index")
|
||||||
df.to_csv(f"result/datetime_post_processed_{type_column}.csv")
|
df.to_csv(f"result/datetime_post_processed_{type_column}.csv")
|
@ -11,6 +11,13 @@ from common.utils_kvu.split_docs import split_docs, merge_sbt_output
|
|||||||
# from api.Kie_Invoice_AP.prediction_fi import predict_fi
|
# from api.Kie_Invoice_AP.prediction_fi import predict_fi
|
||||||
# from api.manulife.predict_manulife import predict as predict_manulife
|
# from api.manulife.predict_manulife import predict as predict_manulife
|
||||||
from api.sdsap_sbt.prediction_sbt import predict as predict_sbt
|
from api.sdsap_sbt.prediction_sbt import predict as predict_sbt
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/'
|
os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/'
|
||||||
|
|
||||||
@ -188,11 +195,11 @@ def compile_output_manulife(list_url):
|
|||||||
outputs = []
|
outputs = []
|
||||||
for page in list_url:
|
for page in list_url:
|
||||||
output_model = predict_manulife(page['page_number'], page['file_url']) # gotta be predict_manulife(), for the time being, this function is not avaible, we just leave a dummy function here instead
|
output_model = predict_manulife(page['page_number'], page['file_url']) # gotta be predict_manulife(), for the time being, this function is not avaible, we just leave a dummy function here instead
|
||||||
print("output_model", output_model)
|
logger.info("output_model", output_model)
|
||||||
outputs.append(output_model)
|
outputs.append(output_model)
|
||||||
print("outputs", outputs)
|
logger.info("outputs", outputs)
|
||||||
documents = split_docs(outputs)
|
documents = split_docs(outputs)
|
||||||
print("documents", documents)
|
logger.info("documents", documents)
|
||||||
results = {
|
results = {
|
||||||
"total_pages": len(list_url),
|
"total_pages": len(list_url),
|
||||||
"ocr_num_pages": len(list_url),
|
"ocr_num_pages": len(list_url),
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# tuplify
|
# tuplify
|
||||||
def tup(point):
|
def tup(point):
|
||||||
@ -85,7 +92,7 @@ while not finished:
|
|||||||
finished = True
|
finished = True
|
||||||
|
|
||||||
# check progress
|
# check progress
|
||||||
print("Len Boxes: " + str(len(boxes)))
|
logger.info("Len Boxes: " + str(len(boxes)))
|
||||||
|
|
||||||
# draw boxes # comment this section out to run faster
|
# draw boxes # comment this section out to run faster
|
||||||
copy = np.copy(orig)
|
copy = np.copy(orig)
|
||||||
|
@ -4,6 +4,13 @@ from sdsvtr import StandaloneSATRNRunner
|
|||||||
from sdsvtd import StandaloneYOLOXRunner
|
from sdsvtd import StandaloneYOLOXRunner
|
||||||
import urllib
|
import urllib
|
||||||
import cv2
|
import cv2
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class YoloX:
|
class YoloX:
|
||||||
@ -50,8 +57,8 @@ class OcrEngineForYoloX_Invoice:
|
|||||||
lbboxes.append(bbox_)
|
lbboxes.append(bbox_)
|
||||||
lcropped_img.append(crop_img)
|
lcropped_img.append(crop_img)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
logger.info(e)
|
||||||
print(f"[ERROR]: Skipping invalid bbox in image")
|
logger.info(f"[ERROR]: Skipping invalid bbox in image")
|
||||||
lwords, _ = self.cls.inference(lcropped_img)
|
lwords, _ = self.cls.inference(lcropped_img)
|
||||||
return lbboxes, lwords
|
return lbboxes, lwords
|
||||||
|
|
||||||
@ -72,6 +79,6 @@ class OcrEngineForYoloX_ID_Driving:
|
|||||||
lbboxes.append(bbox_)
|
lbboxes.append(bbox_)
|
||||||
lcropped_img.append(crop_img)
|
lcropped_img.append(crop_img)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
print(f"[ERROR]: Skipping invalid bbox image in ")
|
logger.info(f"[ERROR]: Skipping invalid bbox image in ")
|
||||||
lwords, _ = self.cls.inference(lcropped_img)
|
lwords, _ = self.cls.inference(lcropped_img)
|
||||||
return lbboxes, lwords
|
return lbboxes, lwords
|
||||||
|
@ -5,7 +5,13 @@ from xml.dom.expatbuilder import parseString
|
|||||||
from lxml.etree import Element, tostring, SubElement
|
from lxml.etree import Element, tostring, SubElement
|
||||||
import tqdm
|
import tqdm
|
||||||
from common.utils.global_variables import *
|
from common.utils.global_variables import *
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
|
def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
|
||||||
"""_summary_
|
"""_summary_
|
||||||
@ -75,7 +81,7 @@ def boxes_to_xml(boxes_lst, xml_pth, img_pth=""):
|
|||||||
node_ymax = SubElement(node_bndbox, "ymax")
|
node_ymax = SubElement(node_bndbox, "ymax")
|
||||||
node_ymax.text = bottom
|
node_ymax.text = bottom
|
||||||
|
|
||||||
xml = tostring(node_root, pretty_print=True)
|
xml = tostring(node_root, pretty_print=True)
|
||||||
dom = parseString(xml)
|
dom = parseString(xml)
|
||||||
with open(xml_pth, "w+", encoding="utf-8") as f:
|
with open(xml_pth, "w+", encoding="utf-8") as f:
|
||||||
dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8")
|
dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8")
|
||||||
@ -105,7 +111,7 @@ def check_iou(box1: Box, box2: Box, threshold=0.9):
|
|||||||
ymax_intersect * ymin_intersect
|
ymax_intersect * ymin_intersect
|
||||||
)
|
)
|
||||||
union = area1 + area2 - area_intersect
|
union = area1 + area2 - area_intersect
|
||||||
print(union)
|
logger.info(union)
|
||||||
iou = area_intersect / area1
|
iou = area_intersect / area1
|
||||||
if iou > threshold:
|
if iou > threshold:
|
||||||
return True
|
return True
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
from builtins import dict
|
from builtins import dict
|
||||||
from common.utils.global_variables import *
|
from common.utils.global_variables import *
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from utils.logging.logging import LOGGER_CONFIG
|
||||||
|
# Load the logging configuration
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG)
|
||||||
|
# Get the logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MIN_IOU_HEIGHT = 0.7
|
MIN_IOU_HEIGHT = 0.7
|
||||||
MIN_WIDTH_LINE_RATIO = 0.05
|
MIN_WIDTH_LINE_RATIO = 0.05
|
||||||
@ -62,7 +69,7 @@ class Word_group:
|
|||||||
if word.text != "✪":
|
if word.text != "✪":
|
||||||
for w in self.list_words:
|
for w in self.list_words:
|
||||||
if word.word_id == w.word_id:
|
if word.word_id == w.word_id:
|
||||||
print("Word id collision")
|
logger.info("Word id collision")
|
||||||
return False
|
return False
|
||||||
word.word_group_id = self.word_group_id #
|
word.word_group_id = self.word_group_id #
|
||||||
word.line_id = self.line_id
|
word.line_id = self.line_id
|
||||||
@ -120,7 +127,7 @@ class Line:
 if word_group.list_words is not None:
 for wg in self.list_word_groups:
 if word_group.word_group_id == wg.word_group_id:
-print("Word_group id collision")
+logger.info("Word_group id collision")
 return False

 self.list_word_groups.append(word_group)
@ -204,7 +211,7 @@ class Paragraph:
 if line.list_word_groups is not None:
 for l in self.list_lines:
 if line.line_id == l.line_id:
-print("Line id collision")
+logger.info("Line id collision")
 return False
 for i in range(len(line.list_word_groups)):
 line.list_word_groups[
@ -288,7 +295,7 @@ def prepare_line(words):
 new_line.merge_word(word)
 lines.append(new_line)

-# print(len(lines))
+# logger.info(len(lines))
 # sort line from top to bottom according top coordinate
 lines.sort(key=lambda x: x.boundingbox[1])
 return lines
@ -381,7 +388,7 @@ def words_to_lines(words, check_special_lines=True): # words is list of Word in
 # sort word by top
 words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0]))
 number_of_word = len(words)
-# print(number_of_word)
+# logger.info(number_of_word)
 # sort list words to list lines, which have not contained word_group yet
 lines = prepare_line(words)

@ -402,7 +409,7 @@ def near(word_group1: Word_group, word_group2: Word_group):
 if overlap > 0:
 return True
 if abs(overlap / min_height) < 1.5:
-print("near enough", abs(overlap / min_height), overlap, min_height)
+logger.info(f"near enough {abs(overlap / min_height)} {overlap} {min_height}")
 return True
 return False

@ -102,8 +102,6 @@ def merge_sbt_output(loutputs):
 })
 return output

-print("concat outputs: \n", loutputs)
-
 merged_output = []
 combined_output = {"retailername": None,
 "sold_to_party": None,
@ -1 +1 @@
-Subproject commit a3f2fea0154fb9098492c834155338fc47dc1527
+Subproject commit be37541e48bcf2045be3e375319fdb69aa8bcef0
cope2n-ai-fi/utils/__init__.py (new file, 0 lines)
cope2n-ai-fi/utils/logging/__init___.py (new file, 0 lines)
cope2n-ai-fi/utils/logging/local_storage.py (new file, 15 lines)
@ -0,0 +1,15 @@
+from threading import local
+
+_thread_locals = local()
+
+def get_current_request():
+    return getattr(_thread_locals, 'request', None)
+
+def set_current_request(request):
+    _thread_locals.request = request
+
+def set_current_trace_id(trace_id):
+    _thread_locals.trace_id = trace_id
+
+def get_current_trace_id():
+    return getattr(_thread_locals, 'trace_id', None)
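Taken together, these helpers give each thread its own request and trace-id slot via threading.local. A minimal usage sketch, hypothetical and not part of the commit, with the import path assumed from the file listing above:

from threading import Thread
from utils.logging.local_storage import set_current_trace_id, get_current_trace_id

set_current_trace_id("req-1234")                 # typically set once per request or task
assert get_current_trace_id() == "req-1234"      # visible everywhere in the same thread

def other_thread():
    # a thread that never set a value falls back to None
    assert get_current_trace_id() is None

t = Thread(target=other_thread)
t.start()
t.join()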
cope2n-ai-fi/utils/logging/logging.py (new file, 61 lines)
@ -0,0 +1,61 @@
+import os
+import logging
+import logging.config
+from .local_storage import set_current_trace_id, get_current_trace_id
+
+class TraceIDLogFilter(logging.Filter):
+    def filter(self, record):
+        trace_id = get_current_trace_id()
+        record.trace_id = trace_id
+        return True
+
+LOG_ROOT = os.getenv("LOG_ROOT", "/home/tuanlv/workspace/02-KVU/sdsvkvu/logs")
+
+LOGGER_CONFIG = {
+    "version": 1,
+    "formatters": {
+        "default": {
+            "format": "%(asctime)s - %(name)s - %(levelname)s - %(trace_id)s - %(message)s"
+        }
+    },
+    "filters": {
+        "trace_id": {
+            "()": TraceIDLogFilter
+        }
+    },
+    "handlers": {
+        'console': {
+            'class': 'logging.StreamHandler',
+            'formatter': 'default',
+            'filters': ['trace_id'],
+        },
+        "file_handler": {
+            "class": "logging.handlers.TimedRotatingFileHandler",
+            "filename": f"{LOG_ROOT}/sbt_idp_AI.log",
+            "level": "DEBUG",
+            "formatter": "default",
+            "filters": ["trace_id"],
+            "when": "midnight",
+            "interval": 1,
+            'backupCount': 10,
+        }
+    },
+    "loggers": {
+        "sdsvkvu": {
+            "level": "DEBUG",
+            "handlers": ["console", "file_handler"],
+        },
+        '': {
+            'handlers': ['console', 'file_handler'],
+            'level': 'INFO',
+        },
+        'django': {
+            'handlers': ['console', 'file_handler'],
+            'level': 'INFO',
+        },
+        'celery': {
+            'handlers': ['console', 'file_handler'],
+            'level': 'DEBUG',
+        },
+    }
+}
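A quick, hypothetical smoke test for the configuration above (not part of the commit). It assumes LOG_ROOT points at an existing, writable directory, since dictConfig opens the TimedRotatingFileHandler file immediately:

import logging
import logging.config
from utils.logging.local_storage import set_current_trace_id
from utils.logging.logging import LOGGER_CONFIG

logging.config.dictConfig(LOGGER_CONFIG)
set_current_trace_id("abc-001")
logging.getLogger("sdsvkvu").info("hello")
# each configured handler writes a line of the form (per the "default" formatter):
# 2024-01-01 12:00:00,000 - sdsvkvu - INFO - abc-001 - hello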
@ -282,6 +282,7 @@ LOGGING = {
 'console': {
 'class': 'logging.StreamHandler',
 'formatter': 'verbose',
+'filters': ['trace_id'],
 },
 'file': {
 "class": 'logging.handlers.TimedRotatingFileHandler',
@ -290,6 +291,7 @@ LOGGING = {
 "interval": 1,
 'backupCount': 10,
 'formatter': 'verbose',
+'filters': ['trace_id'],
 },
 },
 'loggers': {
@ -2,7 +2,7 @@ from celery import Celery

 from fwd import settings
 from fwd_api.exception.exceptions import GeneralException
-from fwd_api.middleware.local_storage import get_current_request
+from fwd_api.middleware.local_storage import get_current_trace_id
 from kombu.utils.uuid import uuid
 from celery.utils.log import get_task_logger
 logger = get_task_logger(__name__)
@ -128,9 +128,9 @@ class CeleryConnector:
 def send_task(self, name=None, args=None, countdown=None):
 if name not in self.task_routes or 'queue' not in self.task_routes[name]:
 raise GeneralException("System")
-# task_id = args[0] + "_" + uuid()[:4] if isinstance(args, tuple) and is_it_an_index(args[0]) else uuid()
-request = get_current_request()
-task_id = request.META.get('X-Trace-ID', uuid()) + "_" + uuid()[:4] if request else uuid()
+task_id = args[0] + "_" + uuid()[:4] if isinstance(args, tuple) and is_it_an_index(args[0]) else uuid()
+trace_id = get_current_trace_id()
+args += (trace_id,) # add trace_id to args then remove before start
 logger.info(f"SEND task name: {name} - {task_id} | args: {args} | countdown: {countdown}")
 return self.app.send_task(name, args, queue=self.task_routes[name]['queue'], expires=300, countdown=countdown, task_id=task_id)

@ -16,6 +16,7 @@ from ..utils import process as ProcessUtil
 from ..utils import s3 as S3Util
 from ..utils.accuracy import validate_feedback_file
 from fwd_api.constant.common import FileCategory
+from fwd_api.middleware.local_storage import get_current_trace_id
 import csv
 import json
 import copy
@ -222,6 +223,8 @@ def process_pdf(rq_id, sub_id, p_type, user_id, files):
 file_meta["preprocessing_time"] = preprocessing_time
 file_meta["index_to_image_type"] = b_url["index_to_image_type"]
 file_meta["subsidiary"] = new_request.subsidiary
+file_meta["request_id"] = rq_id
+file_meta["trace_id"] = get_current_trace_id()
 to_queue.append((fractorized_request_id, sub_id, [b_url], user_id, p_type, file_meta))

 # Send to next queue
@ -1,5 +1,6 @@
 from celery import Task
 from celery.utils.log import get_task_logger
+from fwd_api.middleware.local_storage import get_current_trace_id, set_current_trace_id
 logger = get_task_logger(__name__)

 class VerboseTask(Task):
@ -13,4 +14,7 @@ class VerboseTask(Task):
 logger.info(f"SUCCESS: Task: {self.name} - {task_id} | retval: {retval} | args: {args} | kwargs: {kwargs}")

 def before_start(self, task_id, args, kwargs):
+trace_id = args[-1]
+args.pop(-1)
+set_current_trace_id(trace_id)
 logger.info(f"BEFORE_START: Task: {self.name} - {task_id} | args: {args} | kwargs: {kwargs}")
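The two Celery hunks above form a round trip: send_task appends the caller's trace id as the last positional argument, and VerboseTask.before_start pops it back off and installs it in the worker's thread-local storage before the task body runs, so log records emitted by the task carry the same trace id as the originating request. A condensed, hypothetical sketch of that hand-off (names taken from the diff; the Celery transport itself is omitted):

from fwd_api.middleware.local_storage import set_current_trace_id, get_current_trace_id

def producer_side(args: tuple) -> tuple:
    # mirrors CeleryConnector.send_task: the trace id rides along as the last argument
    return args + (get_current_trace_id(),)

def worker_side(args: list) -> list:
    # mirrors VerboseTask.before_start: restore the original args and re-install the trace id
    trace_id = args[-1]
    args.pop(-1)
    set_current_trace_id(trace_id)
    return args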
@ -7,3 +7,9 @@ def get_current_request():

 def set_current_request(request):
 _thread_locals.request = request
+
+def set_current_trace_id(trace_id):
+_thread_locals.trace_id = trace_id
+
+def get_current_trace_id():
+return getattr(_thread_locals, 'trace_id', None)
@ -2,7 +2,7 @@ import logging
 import uuid

 from django.utils.deprecation import MiddlewareMixin
-from .local_storage import set_current_request, get_current_request
+from .local_storage import set_current_trace_id, get_current_trace_id

 logger = logging.getLogger(__name__)

@ -10,7 +10,7 @@ class LoggingMiddleware(MiddlewareMixin):
 def process_request(self, request):
 trace_id = request.headers.get('X-Trace-ID', str(uuid.uuid4()))
 request.META['X-Trace-ID'] = trace_id
-set_current_request(request)
+set_current_trace_id(trace_id)

 request_body = ""
 content_type = request.headers.get("Content-Type", "")
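On the HTTP side, LoggingMiddleware takes the trace id from the X-Trace-ID header and falls back to a fresh uuid4 when the header is missing, so a caller can correlate its own logs with the server's. A hypothetical client call, with the endpoint path and port made up for illustration:

import requests

resp = requests.post(
    "http://localhost:9000/api/ctel/images/process/",   # hypothetical endpoint
    headers={"X-Trace-ID": "my-trace-0001"},             # picked up by LoggingMiddleware.process_request
    files={"file": open("invoice.jpg", "rb")},
)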
@ -41,7 +41,6 @@ class LoggingMiddleware(MiddlewareMixin):

 class TraceIDLogFilter(logging.Filter):
 def filter(self, record):
-request = get_current_request()
-trace_id = request.META.get('X-Trace-ID', 'unknown') if request else 'unknown'
+trace_id = get_current_trace_id()
 record.trace_id = trace_id
 return True
@ -1 +1 @@
-Subproject commit a3f2fea0154fb9098492c834155338fc47dc1527
+Subproject commit be37541e48bcf2045be3e375319fdb69aa8bcef0
@ -12,16 +12,16 @@ services:
 shm_size: 10gb
 dockerfile: Dockerfile
 shm_size: 10gb
-restart: always
 networks:
 - ctel-sbt
 privileged: true
 image: sidp/cope2n-ai-fi-sbt:latest
 # runtime: nvidia
 environment:
+- LOG_ROOT=${AI_LOG_ROOT}
 - PYTHONPATH=${PYTHONPATH}:/workspace/cope2n-ai-fi # For import module
 - CELERY_BROKER=amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq-sbt:5672
-# - CUDA_VISIBLE_DEVICES=0
+- CUDA_VISIBLE_DEVICES=1
 volumes:
 - ./cope2n-ai-fi:/workspace/cope2n-ai-fi # for dev container only
 working_dir: /workspace/cope2n-ai-fi
@ -15,6 +15,7 @@ services:
 - ctel-sbt
 privileged: true
 environment:
+- LOG_ROOT=${AI_LOG_ROOT}
 - CELERY_BROKER=amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq-sbt:5672
 working_dir: /workspace/cope2n-ai-fi
 command: bash run.sh