228 lines
10 KiB
Python
Executable File
228 lines
10 KiB
Python
Executable File
from omegaconf import OmegaConf
|
|
import torch
|
|
# from functions import get_colormap, visualize
|
|
import sys
|
|
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ??????
|
|
|
|
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
|
from model import get_model
|
|
from utils import load_model_weight
|
|
|
|
|
|
class KVUPredictor:
    """Inference wrapper around a Key-Value Understanding (KVU) network.

    The network is built with ``get_model`` from an OmegaConf config and its
    weights are loaded from a checkpoint. ``mode`` selects the path that
    ``predict`` dispatches to, and is also written into ``cfg.stage`` before
    the model is built:

    * 0 -- ``combined_predict`` on a single sample with a top-level ``'words'``.
    * 1 -- ``cat_predict`` on a sample with ``'documents'`` and ``'windows'``.
    * 2 -- ``combined_predict`` run independently on every ``'windows'`` entry.
    * 3 -- ``doc_predict`` on a sample with ``'documents'`` and ``'windows'``.
    """

    # Head output keys, in the order _decode_head_outputs passes the argmaxed
    # labels to _parse_labels (itc, stc, el, el_from_key).
    _HEAD_KEYS = ("itc_outputs", "stc_outputs", "el_outputs", "el_outputs_from_key")

    def __init__(self, configs, class_names, dummy_idx, mode=0, device='cuda'):
        """Load the model and read the windowing settings from its config.

        Args:
            configs: mapping with ``'cfg'`` (path loaded by ``OmegaConf.load``)
                and ``'ckpt'`` (checkpoint path passed to ``load_model_weight``).
            class_names: forwarded to ``parse_initial_words``.
            dummy_idx: index forwarded to ``parse_subsequent_words`` and
                ``parse_relations``. In mode 3 it is multiplied by
                ``cfg.train.max_window_count``. NOTE(review): presumably the
                index of a dummy/padding node; confirm against those helpers.
            mode: prediction mode (see the class docstring).
            device: torch device for the network and its inputs. Defaults to
                ``'cuda'``, the device the model was previously hard-wired to.
        """
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']

        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode
        # Set before _load_model, which moves the network to this device, so
        # the model and the inputs built in the predict methods always agree.
        self.device = device

        print('[INFO] Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        print("[INFO] Loaded model")

        if mode == 3:
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

    def _load_model(self, cfg_path, ckpt_path):
        """Build the network, load its weights, and put it in eval mode on ``self.device``.

        Returns:
            ``(net, cfg, backbone_type)``, where ``backbone_type`` is
            ``cfg.model.backbone``.
        """
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        print('[INFO] Checkpoint:', ckpt_path)
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to(self.device)
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        """Run the prediction path selected by ``self.mode``.

        Returns:
            ``(bboxes, words, class_words, relations)``: four parallel lists.
            Modes 0, 1 and 3 return one-element lists. Mode 2 returns one
            element per window. Four empty lists are returned when the sample
            has no words (or, in mode 2, no windows at all).

        Raises:
            ValueError: if ``self.mode`` is not one of 0-3.
        """
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        if self.mode in (1, 3):
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            predict_fn = self.cat_predict if self.mode == 1 else self.doc_predict
            bbox, lwords, pr_class_words, pr_relations = predict_fn(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        if self.mode == 2:
            windows = input_sample['windows']
            # Check for an empty window list first: windows[0] would otherwise
            # raise IndexError instead of producing an empty result.
            if len(windows) == 0 or len(windows[0]['words']) == 0:
                return [], [], [], []
            bboxes, words, class_words, relations = [], [], [], []
            for window in windows:
                bbox, lwords, pr_class_words, pr_relations = self.combined_predict(window)
                bboxes.append(bbox)
                words.append(lwords)
                class_words.append(pr_class_words)
                relations.append(pr_relations)
            return bboxes, words, class_words, relations

        raise ValueError(
            f"Not supported mode: {self.mode}"
        )

    def doc_predict(self, input_sample):
        """Predict over a whole document from its windows (mode 3).

        Each window's tensors (all keys except ``'words'`` and
        ``'n_empty_windows'``) get a leading axis via ``unsqueeze(0)`` and are
        moved to ``self.device``. ``input_sample['documents']`` is passed to
        the network unchanged. The caller's dict is not modified.

        Returns:
            ``(bbox, words, class_words, relations)``. The masks and ``bbox``
            come from ``'documents'``, squeezed on axis 0.
        """
        documents = input_sample['documents']
        lwords = documents['words']

        # Shallow copy so the caller's 'windows' list is left untouched.
        model_input = dict(input_sample)
        model_input['windows'] = [
            {k: v.unsqueeze(0).to(self.device) for k, v in window.items()
             if k not in ('words', 'n_empty_windows')}
            for window in input_sample['windows']
        ]

        with torch.no_grad():
            head_outputs, _ = self.net(model_input)
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}

        box_first_token_mask = documents['are_box_first_tokens'].squeeze(0)
        attention_mask = documents['attention_mask_layoutxlm'].squeeze(0)
        bbox = documents['bbox'].squeeze(0)

        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs, box_first_token_mask, attention_mask, self.dummy_idx
        )
        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        """Predict on one un-batched sample (modes 0 and 2).

        Every key except ``'words'`` and ``'img_path'`` is treated as a tensor.
        Each one gets a leading axis via ``unsqueeze(0)`` and is moved to
        ``self.device``. The caller's dict is not modified.

        Returns:
            ``(bbox, words, class_words, relations)``, with ``bbox`` on the CPU.
        """
        lwords = input_sample['words']
        device_inputs = {
            k: v.unsqueeze(0).to(self.device)
            for k, v in input_sample.items()
            if k not in ('words', 'img_path')
        }

        with torch.no_grad():
            head_outputs, _ = self.net(device_inputs)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        cpu_inputs = {k: v.detach().cpu() for k, v in device_inputs.items()}

        box_first_token_mask = cpu_inputs['are_box_first_tokens'].squeeze(0)
        attention_mask = cpu_inputs['attention_mask_layoutxlm'].squeeze(0)
        bbox = cpu_inputs['bbox'].squeeze(0)

        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs, box_first_token_mask, attention_mask, self.dummy_idx
        )
        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        """Predict in mode 1 from a sample with ``'documents'`` and ``'windows'``.

        Each window's tensors (all keys except ``'words'`` and ``'img_path'``)
        get a leading axis via ``unsqueeze(0)`` and are moved to
        ``self.device``. The caller's dict is not modified. The masks and
        ``bbox`` are read from ``'documents'`` as-is (no squeeze). The dummy
        index is ``documents['bbox'].shape[0]`` rather than ``self.dummy_idx``.
        NOTE(review): this assumes ``documents['bbox']`` has no batch axis, so
        ``shape[0]`` is the token count; confirm.

        Returns:
            ``(bbox, words, class_words, relations)``.
        """
        documents = input_sample['documents']
        lwords = documents['words']

        # Shallow copy so the caller's 'windows' list is left untouched.
        model_input = dict(input_sample)
        model_input['windows'] = [
            {k: v.unsqueeze(0).to(self.device) for k, v in window.items()
             if k not in ('words', 'img_path')}
            for window in input_sample['windows']
        ]

        with torch.no_grad():
            head_outputs, _ = self.net(model_input)

        # Exclude exactly this key. Writing ('embedding_tokens') without a
        # trailing comma makes a plain str, so `in` would be a substring test.
        head_outputs = {
            k: v.detach().cpu() for k, v in head_outputs.items() if k != 'embedding_tokens'
        }

        box_first_token_mask = documents['are_box_first_tokens']
        attention_mask = documents['attention_mask_layoutxlm']
        bbox = documents['bbox']
        dummy_idx = bbox.shape[0]

        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs, box_first_token_mask, attention_mask, dummy_idx
        )
        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        """Parse ground-truth label tensors the same way predictions are parsed.

        Returns:
            ``(bbox, words, class_words, relations)``. The parsing uses
            ``self.dummy_idx``.
        """
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)
        gt_el_label = ground_truth['el_labels'].squeeze(0)
        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        # NOTE(review): the predict paths read 'attention_mask_layoutxlm', but
        # this one reads 'attention_mask'. Confirm this matches the GT schema.
        attention_mask = ground_truth['attention_mask'].squeeze(0)
        bbox = ground_truth['bbox'].squeeze(0)

        gt_class_words, gt_relations = self._parse_labels(
            gt_itc_label, gt_stc_label, gt_el_label, gt_el_label_from_key,
            box_first_token_mask, attention_mask, self.dummy_idx,
        )
        return bbox, lwords, gt_class_words, gt_relations

    def _decode_head_outputs(self, head_outputs, box_first_token_mask, attention_mask, dummy_idx):
        """Argmax each head's logits over the last axis, then parse them with _parse_labels."""
        labels = [torch.argmax(head_outputs[key], -1).squeeze(0) for key in self._HEAD_KEYS]
        return self._parse_labels(*labels, box_first_token_mask, attention_mask, dummy_idx)

    def _parse_labels(self, itc_label, stc_label, el_label, el_label_from_key,
                      box_first_token_mask, attention_mask, dummy_idx):
        """Turn per-token label tensors into grouped class words and relations.

        Returns:
            ``(class_words, relations)``. ``relations`` is the ``|`` union of
            the from-header and from-key relations.
        """
        first_words = parse_initial_words(itc_label, box_first_token_mask, self.class_names)
        class_words = parse_subsequent_words(stc_label, attention_mask, first_words, dummy_idx)
        relations_from_header = parse_relations(el_label, box_first_token_mask, dummy_idx)
        relations_from_key = parse_relations(el_label_from_key, box_first_token_mask, dummy_idx)
        return class_words, relations_from_header | relations_from_key