from omegaconf import OmegaConf
import torch
# from functions import get_colormap, visualize
import sys
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/')  # TODO: ??????

from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight

import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG

# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)


class KVUPredictor:
    def __init__(self, configs, class_names, dummy_idx, mode=0):
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']
        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode

        logger.info('[INFO] Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        logger.info('[INFO] Loaded model')

        if mode == 3:
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

        self.device = 'cuda'

    def _load_model(self, cfg_path, ckpt_path):
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        logger.info('[INFO] Checkpoint: %s', ckpt_path)
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to('cuda')
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 1:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 2:
            if len(input_sample['windows'][0]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations

        elif self.mode == 3:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        else:
            raise ValueError(f"Unsupported mode: {self.mode}")
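    # The three helpers below (doc_predict, combined_predict, cat_predict) share the
    # same decoding flow: run the network under torch.no_grad(), take the argmax of
    # the itc/stc/el heads, parse the initial and subsequent words per class, and
    # merge the relation sets recovered from the header and key heads.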
    def doc_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}
        # input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = input_sample['documents']

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(pr_stc_label, attention_mask, pr_init_words, self.dummy_idx)
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        lwords = input_sample['words']
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}
        input_sample = {k: v.to(self.device) for k, v in input_sample.items()}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(pr_stc_label, attention_mask, pr_init_words, self.dummy_idx)
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        lwords = input_sample['documents']['words']

        inputs = []
        for window in input_sample['windows']:
            inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k != 'embedding_tokens'}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['documents']['are_box_first_tokens']
        attention_mask = input_sample['documents']['attention_mask_layoutxlm']
        bbox = input_sample['documents']['bbox']
        dummy_idx = input_sample['documents']['bbox'].shape[0]

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(pr_stc_label, attention_mask, pr_init_words, dummy_idx)
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations
    def get_ground_truth_label(self, ground_truth):
        # ground_truth = self.preprocessor.load_ground_truth(json_file)
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_el_label = ground_truth['el_labels'].squeeze(0)
        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        attention_mask = ground_truth['attention_mask'].squeeze(0)
        bbox = ground_truth['bbox'].squeeze(0)

        gt_first_words = parse_initial_words(gt_itc_label, box_first_token_mask, self.class_names)
        gt_class_words = parse_subsequent_words(gt_stc_label, attention_mask, gt_first_words, self.dummy_idx)
        gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
        gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
        gt_relations = gt_relations_from_header | gt_relations_from_key

        return bbox, lwords, gt_class_words, gt_relations
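
# Minimal usage sketch (illustrative only): the config/checkpoint paths, class
# names, and dummy_idx below are hypothetical placeholders, and `input_sample`
# is assumed to come from the project's own preprocessing pipeline.
if __name__ == "__main__":
    configs = {
        'cfg': '/path/to/experiment_config.yaml',  # hypothetical OmegaConf config
        'ckpt': '/path/to/checkpoint.pth',         # hypothetical trained weights
    }
    class_names = ['others', 'title', 'key', 'value']  # placeholder label set
    dummy_idx = 512  # placeholder; should match the tokenized sequence length used in training

    predictor = KVUPredictor(configs, class_names, dummy_idx, mode=0)

    # `input_sample` must be a dict holding the tokenized words, bounding boxes and
    # attention masks produced by the project's preprocessor (not shown here):
    # bbox, lwords, pr_class_words, pr_relations = predictor.predict(input_sample)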