# sbt-idp/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/predictor.py
import os

import torch
from omegaconf import OmegaConf

from sdsvkvu.model import get_model, load_model_weight
from sdsvkvu.sources.utils import parse_initial_words, parse_subsequent_words, parse_relations

class KVUPredictor:
    """Inference wrapper around the KVU model.

    Modes: 0 = normal, 1 = full tokens, 2 = sliding window,
    3 = document, 4 = SBT.
    """

    def __init__(self, configs):
        self.mode = configs.mode
        self.device = configs.device
        self.pretrained_model_path = configs.model.pretrained_model_path
        net, cfg = self._load_model(configs.model.config, configs.model.checkpoint)
        self.model = net

        self.class_names = cfg.model.class_names
        self.max_seq_length = cfg.train.max_seq_length
        self.backbone_type = cfg.model.backbone

        if self.mode in (3, 4):  # document / SBT: fixed multi-window layout
            self.slice_interval = 0
            self.window_size = cfg.train.window_size
            self.max_window_count = cfg.train.max_window_count
            self.dummy_idx = self.max_seq_length * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words
            self.max_window_count = 1
            if self.mode == 2:
                self.dummy_idx = 0  # dynamic dummy
            else:
                self.dummy_idx = self.max_seq_length  # 512

    def get_process_configs(self):
        _settings = {
            # "tokenizer_layoutxlm": self.model.tokenizer_layoutxlm,
            # "feature_extractor": self.model.feature_extractor,
            "class_names": self.class_names,
            "backbone_type": self.backbone_type,
            "window_size": self.window_size,
            "slice_interval": self.slice_interval,
            "max_window_count": self.max_window_count,
            "max_seq_length": self.max_seq_length,
            "device": self.device,
            "mode": self.mode,
        }
        feature_extractor = self.model.feature_extractor
        if self.mode in (3, 4):
            tokenizer_layoutxlm = self.model.tokenizer
        else:
            tokenizer_layoutxlm = self.model.tokenizer_layoutxlm
        return OmegaConf.create(_settings), tokenizer_layoutxlm, feature_extractor

    def _load_model(self, cfg_path, ckpt_path):
        cfg = OmegaConf.load(cfg_path)
        if self.pretrained_model_path is not None and os.path.exists(self.pretrained_model_path):
            cfg.model.pretrained_model_path = self.pretrained_model_path
            print("[INFO] Loading pretrained backbone from:", cfg.model.pretrained_model_path)
        cfg.mode = self.mode
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to(self.device)
        net.eval()
        return net, cfg

    def predict(self, input_sample):
        if self.mode == 0:  # Normal
            bbox, lwords, pr_class_words, pr_relations = self.com_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 1:  # Full - tokens
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 2:  # Sliding
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.com_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations
        elif self.mode == 3:  # Document
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 4:  # SBT
            bbox, lwords, pr_class_words, pr_relations = self.sbt_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")

    def doc_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {
                k: v.unsqueeze(0).to(self.device)
                for k, v in window.items()
                if k not in ('words', 'n_empty_windows')
            }

        with torch.no_grad():
            head_outputs, _ = self.model(input_sample)

        input_sample = input_sample['documents']
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        # input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}
        bbox = input_sample['bbox'].squeeze(0)
        pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def com_predict(self, input_sample):
        lwords = input_sample['words']
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}
        input_sample = {k: v.to(self.device) for k, v in input_sample.items()}

        with torch.no_grad():
            head_outputs, _ = self.model(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}
        bbox = input_sample['bbox'].squeeze(0)
        pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        inputs = []
        for window in input_sample['windows']:
            # Use the configured device instead of a hard-coded .cuda()
            inputs.append({k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.model(input_sample)

        # Note the trailing comma: ('embedding_tokens') without it is a plain
        # string, and `k not in <str>` would do substring matching instead.
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample["documents"].items()}
        bbox = input_sample['bbox'].squeeze(0)
        self.dummy_idx = bbox.shape[0]  # dynamic dummy index for full-token mode
        pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def kvu_parser(self, input_sample, head_outputs):
        # itc/stc: initial / subsequent token classification; el: entity linking
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key
        return pr_class_words, pr_relations

    def sbt_predict(self, input_sample):
        lwords = input_sample['documents']['words']
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {
                k: v.unsqueeze(0).to(self.device)
                for k, v in window.items()
                if k not in ('words', 'n_empty_windows')
            }

        with torch.no_grad():
            head_outputs, _ = self.model(input_sample)

        input_sample = input_sample['documents']
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        # input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}
        bbox = input_sample['bbox'].squeeze(0)
        pr_class_words, pr_relations = self.sbt_parser(input_sample, head_outputs)
        return bbox, lwords, pr_class_words, pr_relations

    def sbt_parser(self, input_sample, head_outputs):
        # Same as kvu_parser, but without the key-based entity-linking head
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )
        pr_relations = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        return pr_class_words, pr_relations
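

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes an
    # OmegaConf config with exactly the fields read in __init__; the paths
    # below are placeholders, and `input_sample` must be produced by the
    # project's preprocessing pipeline using the settings, tokenizer, and
    # feature extractor returned from get_process_configs().
    configs = OmegaConf.create({
        "mode": 0,  # 0 = normal single-page inference
        "device": "cuda:0" if torch.cuda.is_available() else "cpu",
        "model": {
            "config": "path/to/kvu_config.yaml",         # placeholder path
            "checkpoint": "path/to/kvu_checkpoint.pth",  # placeholder path
            "pretrained_model_path": None,
        },
    })
    predictor = KVUPredictor(configs)
    settings, tokenizer_layoutxlm, feature_extractor = predictor.get_process_configs()
    # input_sample = build_input(...)  # hypothetical preprocessing step
    # bbox, lwords, pr_class_words, pr_relations = predictor.predict(input_sample)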