225 lines
9.6 KiB
Python
225 lines
9.6 KiB
Python
import torch
|
|
from pathlib import Path
|
|
from omegaconf import OmegaConf
|
|
|
|
import os
|
|
from sdsvkvu.sources.utils import parse_initial_words, parse_subsequent_words, parse_relations
|
|
from sdsvkvu.model import get_model, load_model_weight
|
|
|
|
|
|
class KVUPredictor:
    """Run key-value understanding (KVU) inference with a checkpointed model.

    The network, its training config and its weights are loaded once in
    ``__init__``. :meth:`predict` then dispatches on ``self.mode``:

    * 0 - normal: a single window, via :meth:`com_predict`.
    * 1 - full tokens: all windows encoded together and parsed over the
      whole document, via :meth:`cat_predict`.
    * 2 - sliding: every window predicted independently via
      :meth:`com_predict`.
    * 3 - document: via :meth:`doc_predict` (header and key relation heads).
    * 4 - SBT: via :meth:`sbt_predict` (single relation head).
    """

    def __init__(self, configs):
        """Load the model described by ``configs`` and cache decoding settings.

        Args:
            configs: config object exposing ``mode``, ``device`` and
                ``model.config`` / ``model.checkpoint`` /
                ``model.pretrained_model_path``.
        """
        self.mode = configs.mode
        self.device = configs.device
        self.pretrained_model_path = configs.model.pretrained_model_path
        net, cfg = self._load_model(configs.model.config, configs.model.checkpoint)

        self.model = net
        self.class_names = cfg.model.class_names
        self.max_seq_length = cfg.train.max_seq_length
        self.backbone_type = cfg.model.backbone

        if self.mode in (3, 4):
            # Document-level modes: up to `max_window_count` windows, so the
            # dummy index equals the total token capacity of all windows.
            self.slice_interval = 0
            self.window_size = cfg.train.window_size
            self.max_window_count = cfg.train.max_window_count
            self.dummy_idx = self.max_seq_length * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words
            self.max_window_count = 1
            if self.mode == 2:
                # NOTE(review): mode 2 (sliding) parses each window through
                # com_predict with this fixed value, while the per-call
                # dynamic dummy is actually computed in cat_predict (mode 1).
                # Confirm whether this branch was meant for mode 1.
                self.dummy_idx = 0  # dynamic dummy
            else:
                self.dummy_idx = self.max_seq_length  # typically 512

    def get_process_configs(self):
        """Return the settings, tokenizer and feature extractor for preprocessing.

        Returns:
            ``(settings, tokenizer, feature_extractor)``. ``settings`` is an
            OmegaConf config built from this predictor's decoding attributes.
            The tokenizer is ``self.model.tokenizer`` in modes 3/4 and
            ``self.model.tokenizer_layoutxlm`` otherwise.
        """
        _settings = {
            "class_names": self.class_names,
            "backbone_type": self.backbone_type,
            "window_size": self.window_size,
            "slice_interval": self.slice_interval,
            "max_window_count": self.max_window_count,
            "max_seq_length": self.max_seq_length,
            "device": self.device,
            "mode": self.mode,
        }

        feature_extractor = self.model.feature_extractor
        if self.mode in (3, 4):
            tokenizer_layoutxlm = self.model.tokenizer
        else:
            tokenizer_layoutxlm = self.model.tokenizer_layoutxlm

        return OmegaConf.create(_settings), tokenizer_layoutxlm, feature_extractor

    def _load_model(self, cfg_path, ckpt_path):
        """Build the network from ``cfg_path``, load ``ckpt_path`` and set eval mode.

        If ``self.pretrained_model_path`` exists on disk it overrides
        ``cfg.model.pretrained_model_path`` before the network is built.

        Returns:
            ``(net, cfg)``: the network on ``self.device`` and the loaded config.
        """
        cfg = OmegaConf.load(cfg_path)

        if self.pretrained_model_path is not None and os.path.exists(self.pretrained_model_path):
            cfg.model.pretrained_model_path = self.pretrained_model_path
            print("[INFO] Load pretrained backbone at:", cfg.model.pretrained_model_path)

        cfg.mode = self.mode
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to(self.device)
        net.eval()
        return net, cfg

    def predict(self, input_sample):
        """Run inference on one preprocessed sample according to ``self.mode``.

        Returns:
            Four lists ``(bbox, lwords, pr_class_words, pr_relations)``. In
            sliding mode (2) they hold one entry per element of
            ``input_sample['windows']``. In every other mode each list holds
            exactly one entry.

        Raises:
            ValueError: if ``self.mode`` is not one of 0, 1, 2, 3, 4.
        """
        if self.mode == 2:  # Sliding: every window is predicted on its own.
            outputs = ([], [], [], [])
            for window in input_sample['windows']:
                for collected, value in zip(outputs, self.com_predict(window)):
                    collected.append(value)
            return outputs

        single_pass = {
            0: self.com_predict,  # Normal
            1: self.cat_predict,  # Full - tokens
            3: self.doc_predict,  # Document
            4: self.sbt_predict,  # SBT
        }
        if self.mode not in single_pass:
            raise ValueError(f"Not supported mode: {self.mode}")

        bbox, lwords, pr_class_words, pr_relations = single_pass[self.mode](input_sample)
        return [bbox], [lwords], [pr_class_words], [pr_relations]

    def _windows_to_device(self, windows, exclude_keys):
        """Return batched copies of ``windows`` placed on ``self.device``.

        Keys in ``exclude_keys`` are dropped. Every remaining value gets a
        leading batch dimension of 1 and is moved to ``self.device``. The
        caller's window dicts are not modified.
        """
        return [
            {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in exclude_keys}
            for window in windows
        ]

    def _infer(self, model_input, drop_heads=()):
        """Run ``self.model`` without gradients and return its head outputs on CPU.

        The model returns a pair whose second element is ignored. Head
        outputs named in ``drop_heads`` are skipped.
        """
        with torch.no_grad():
            head_outputs, _ = self.model(model_input)
        return {k: v.detach().cpu() for k, v in head_outputs.items() if k not in drop_heads}

    def _predict_document(self, input_sample, parser):
        """Shared body of :meth:`doc_predict` and :meth:`sbt_predict`.

        The windows are encoded by the model, but decoding (``parser``) runs
        on ``input_sample['documents']``. ``input_sample`` itself is not
        mutated: the model receives a shallow copy whose ``'windows'`` entry
        is replaced by batched, on-device windows.
        """
        documents = input_sample['documents']
        lwords = documents['words']

        model_input = dict(input_sample)
        model_input['windows'] = self._windows_to_device(
            input_sample['windows'], ('words', 'n_empty_windows')
        )
        head_outputs = self._infer(model_input)

        bbox = documents['bbox'].squeeze(0)
        pr_class_words, pr_relations = parser(documents, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def doc_predict(self, input_sample):
        """Document mode (3): encode all windows, decode with both relation heads.

        Returns:
            ``(bbox, lwords, pr_class_words, pr_relations)`` for the document.
        """
        return self._predict_document(input_sample, self.kvu_parser)

    def com_predict(self, input_sample):
        """Predict a single window (normal mode, and each window in sliding mode).

        ``'words'`` is returned unchanged as ``lwords`` and ``'img_path'`` is
        ignored. Every other value is assumed to be a tensor: it gets a
        leading batch dimension and is moved to ``self.device`` for inference.

        Returns:
            ``(bbox, lwords, pr_class_words, pr_relations)``.
        """
        lwords = input_sample['words']
        model_input = {
            k: v.unsqueeze(0).to(self.device)
            for k, v in input_sample.items()
            if k not in ('words', 'img_path')
        }

        head_outputs = self._infer(model_input)
        model_input = {k: v.detach().cpu() for k, v in model_input.items()}

        bbox = model_input['bbox'].squeeze(0)
        pr_class_words, pr_relations = self.kvu_parser(model_input, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        """Full-tokens mode (1): encode all windows, parse over the whole document.

        Windows are batched and moved to ``self.device``, and the
        ``embedding_tokens`` head output is discarded. Parsing runs on the
        document-level tensors. As a side effect, ``self.dummy_idx`` is set
        to the document's token count (number of rows in ``bbox``) before
        parsing.

        Returns:
            ``(bbox, lwords, pr_class_words, pr_relations)``.
        """
        documents = input_sample['documents']
        lwords = documents['words']

        model_input = dict(input_sample)
        model_input['windows'] = self._windows_to_device(
            input_sample['windows'], ('words', 'img_path')
        )
        head_outputs = self._infer(model_input, drop_heads=('embedding_tokens',))

        # Only tensors can take a batch dim. 'words' (returned as lwords) is
        # presumably a plain list and would fail on .unsqueeze().
        doc_inputs = {k: v.unsqueeze(0) for k, v in documents.items() if torch.is_tensor(v)}

        bbox = doc_inputs['bbox'].squeeze(0)
        # Dynamic dummy index: the number of tokens in the concatenated document.
        self.dummy_idx = bbox.shape[0]
        pr_class_words, pr_relations = self.kvu_parser(doc_inputs, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def _parse_words(self, input_sample, head_outputs):
        """Decode the ITC/STC heads into per-class words.

        Returns:
            ``(pr_class_words, box_first_token_mask)``. The mask is returned
            so the relation heads can be decoded against it.
        """
        pr_itc_label = torch.argmax(head_outputs["itc_outputs"], -1).squeeze(0)
        pr_stc_label = torch.argmax(head_outputs["stc_outputs"], -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )
        return pr_class_words, box_first_token_mask

    def _parse_links(self, el_logits, box_first_token_mask):
        """Arg-max one entity-linking head and decode it into relations."""
        pr_el_label = torch.argmax(el_logits, -1).squeeze(0)
        return parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)

    def kvu_parser(self, input_sample, head_outputs):
        """Decode head outputs into predicted words and relations.

        Relations from the header head (``el_outputs``) and from the key head
        (``el_outputs_from_key``) are merged with ``|`` (presumably set union).

        Returns:
            ``(pr_class_words, pr_relations)``.
        """
        pr_class_words, box_first_token_mask = self._parse_words(input_sample, head_outputs)
        pr_relations_from_header = self._parse_links(head_outputs["el_outputs"], box_first_token_mask)
        pr_relations_from_key = self._parse_links(head_outputs["el_outputs_from_key"], box_first_token_mask)
        return pr_class_words, pr_relations_from_header | pr_relations_from_key

    def sbt_predict(self, input_sample):
        """SBT mode (4): like :meth:`doc_predict`, but decodes a single relation head.

        Returns:
            ``(bbox, lwords, pr_class_words, pr_relations)`` for the document.
        """
        return self._predict_document(input_sample, self.sbt_parser)

    def sbt_parser(self, input_sample, head_outputs):
        """Decode head outputs using only the ``el_outputs`` relation head.

        Returns:
            ``(pr_class_words, pr_relations)``.
        """
        pr_class_words, box_first_token_mask = self._parse_words(input_sample, head_outputs)
        pr_relations = self._parse_links(head_outputs["el_outputs"], box_first_token_mask)
        return pr_class_words, pr_relations