sbt-idp/cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py

import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model
from transformers import LayoutXLMTokenizer
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint


class DocumentKVUModel(KVUModel):
    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train

        self._get_backbones(self.model_cfg.backbone)
        # if 'pth' in self.model_cfg.ckpt_model_file:
        #     self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone)
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        self.n_classes = self.model_cfg.n_classes + 1  # +1, presumably for a background/"other" class
        self.repr_hidden_size = self.backbone_hidden_size

        # Per-window heads
        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification, starting from key tokens
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # Classification heads for the whole document (the same four heads,
        # applied to the concatenated window representations)
        # (1) Initial token classification
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.repr_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification, starting from key tokens
        self.relation_layer_from_key_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)
        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)
        self.relation_layer_from_key_document.apply(self._init_weight)

    def _get_backbones(self, config_type):
        configs = {
            'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
            'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extractor': LayoutLMv2FeatureExtractor},
            'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
        }

        self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
        if config_type != 'xlm-roberta':
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
        else:
            # AutoTokenizer would return the fast tokenizer by default; force the slow one
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
        self.feature_extractor = configs[config_type]['feature_extractor'](apply_ocr=False)
        self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)

    def forward(self, batches):
        head_outputs_list = []
        loss = 0.0

        # Run each sliding window through the backbone and the per-window heads
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

            if self.freeze:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep the first 512 (text) positions; LayoutXLM appends visual tokens
            # after them. Then (batch, seq, hidden) -> (seq, batch, hidden) for the heads.
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {
                "window_repr": window_repr,
                "itc_outputs": itc_outputs,
                "stc_outputs": stc_outputs,
                "el_outputs": el_outputs,
                "el_outputs_from_key": el_outputs_from_key,
            }

            if any('labels' in key for key in batch):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        # Concatenate window representations along the sequence dimension and
        # run the document-level heads on the full sequence
        batch = batches["documents"]

        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
        }

        if any('labels' in key for key in batch):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss
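

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal sketch of constructing the model from a dict-style config. The
# field names are inferred from the attribute accesses in this file
# (cfg.model.backbone, cfg.model.pretrained_model_path, cfg.train.freeze, ...);
# the parent KVUModel may require additional fields, and the values shown are
# assumptions, so treat this as a sketch rather than a working entry point.
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.create({
#         "model": {
#             "backbone": "layoutxlm",
#             "pretrained_model_path": "microsoft/layoutxlm-base",
#             "head_hidden_size": 128,   # assumed value
#             "head_p_dropout": 0.1,     # assumed value
#             "n_classes": 10,           # task-dependent
#         },
#         "train": {"freeze": False},
#     })
#     model = DocumentKVUModel(cfg)
#     # forward() expects {"windows": [...], "documents": {...}} batches as
#     # produced by the accompanying data pipeline.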