# sbt-idp/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py

from omegaconf import OmegaConf
import os
import cv2
import torch
# from functions import get_colormap, visualize
import sys
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: ???????
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


class KVUPredictor:
    def __init__(self, configs, class_names, dummy_idx, mode=0):
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']
        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode

        logger.info('Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        logger.info("Loaded model")

        if mode == 3:
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

        self.device = 'cuda'
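
    # Note: `configs` is expected to be a dict with at least the keys read above,
    # i.e. {'cfg': <path to the OmegaConf YAML used for training>, 'ckpt': <path to
    # the trained checkpoint>}; the concrete paths depend on the deployment.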

    def _load_model(self, cfg_path, ckpt_path):
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone
        logger.info('Checkpoint: %s', ckpt_path)

        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to('cuda')
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 1:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 2:
            if len(input_sample['windows'][0]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations
        elif self.mode == 3:
            if len(input_sample["documents"]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        else:
            raise ValueError(
                f"Not supported mode: {self.mode}"
            )
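
    # For every mode, predict() returns four parallel lists
    # (bbox, lwords, pr_class_words, pr_relations) with one entry per window:
    # modes 0, 1 and 3 produce a single entry each, while mode 2 produces one
    # entry per window in input_sample['windows'].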

    def doc_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}
        # input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = input_sample['documents']

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        lwords = input_sample['words']
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}
        input_sample = {k: v.to(self.device) for k, v in input_sample.items()}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        lwords = input_sample['documents']['words']

        inputs = []
        for window in input_sample['windows']:
            inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['documents']['are_box_first_tokens']
        attention_mask = input_sample['documents']['attention_mask_layoutxlm']
        bbox = input_sample['documents']['bbox']
        dummy_idx = input_sample['documents']['bbox'].shape[0]

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        # ground_truth = self.preprocessor.load_ground_truth(json_file)
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_el_label = ground_truth['el_labels'].squeeze(0)
        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        attention_mask = ground_truth['attention_mask'].squeeze(0)
        bbox = ground_truth['bbox'].squeeze(0)

        gt_first_words = parse_initial_words(
            gt_itc_label, box_first_token_mask, self.class_names
        )
        gt_class_words = parse_subsequent_words(
            gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
        )

        gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
        gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
        gt_relations = gt_relations_from_header | gt_relations_from_key

        return bbox, lwords, gt_class_words, gt_relations
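

if __name__ == "__main__":
    # Minimal usage sketch. The config/checkpoint paths and the class-name list below
    # are placeholders for illustration only; substitute the values used in your
    # deployment (they are not defined by this file).
    example_configs = {
        'cfg': '/path/to/kvu_config.yaml',      # OmegaConf config used at training time (placeholder)
        'ckpt': '/path/to/kvu_checkpoint.pth',  # trained model weights (placeholder)
    }
    example_class_names = ['others', 'title', 'key', 'value', 'header']  # placeholder label set

    predictor = KVUPredictor(
        configs=example_configs,
        class_names=example_class_names,
        dummy_idx=512,  # maximum token-sequence length used by the parsers (placeholder)
        mode=0,
    )

    # `input_sample` must be a preprocessed sample in the format the chosen mode expects
    # (see predict() above); for mode 0 it is a single window dict with 'words', 'bbox',
    # 'are_box_first_tokens', 'attention_mask_layoutxlm', and the other model inputs.
    # bbox, lwords, pr_class_words, pr_relations = predictor.predict(input_sample)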