# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py
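"""Document-level key-value understanding (KVU) model.

Runs a LayoutLM-family backbone over sliding windows of a document, applies
token-classification and relation-extraction heads to each window, then
concatenates the window representations and applies the same four heads over
the whole document.
"""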
import torch
from torch import nn
from transformers import (
    LayoutLMConfig,
    LayoutLMModel,
    LayoutLMTokenizer,
    LayoutLMv2FeatureExtractor,
)
from transformers import LayoutLMv2Config, LayoutLMv2Model
from transformers import LayoutXLMTokenizer
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel

from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint  # only used by the commented-out checkpoint loading in __init__


class DocumentKVUModel(KVUModel):
    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train

        self._get_backbones(self.model_cfg.backbone)
        # if 'pth' in self.model_cfg.ckpt_model_file:
        #     self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone)
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        self.n_classes = self.model_cfg.n_classes + 1  # one extra slot on top of the configured classes
        self.repr_hidden_size = self.backbone_hidden_size

        # Per-window heads
        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification (links from key tokens)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # Classification layers for the whole document
        # (1) Initial token classification
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.repr_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification (links from key tokens)
        self.relation_layer_from_key_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)
        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)
        self.relation_layer_from_key_document.apply(self._init_weight)

    def _get_backbones(self, config_type):
        configs = {
            'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
            'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extractor': LayoutLMv2FeatureExtractor},
            'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
        }

        self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
        if config_type != 'xlm-roberta':
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
        else:
            # use_fast=False loads the slow (SentencePiece) XLM-RoBERTa tokenizer
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
        self.feature_extractor = configs[config_type]['feature_extractor'](apply_ocr=False)
        self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)

    def forward(self, batches):
        head_outputs_list = []
        loss = 0.0

        # Optionally freeze the backbone so only the heads are trained.
        if self.freeze:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Run the backbone and the per-window heads on each sliding window.
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep only the first 512 token positions and switch to (seq_len, batch, hidden).
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {
                "window_repr": window_repr,
                "itc_outputs": itc_outputs,
                "stc_outputs": stc_outputs,
                "el_outputs": el_outputs,
                "el_outputs_from_key": el_outputs_from_key,
            }

            if any('labels' in key for key in batch):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        # Concatenate the window representations along the sequence axis and
        # run the document-level heads over the full document.
        batch = batches["documents"]

        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
        }

        if any('labels' in key for key in batch):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss
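

# --- Hedged usage sketch (not part of the original module) -------------------
# A minimal illustration of how this class might be instantiated and called,
# assuming an OmegaConf-style config exposing the keys the class actually
# reads (cfg.model.backbone, cfg.model.pretrained_model_path,
# cfg.model.n_classes, cfg.model.head_hidden_size, cfg.model.head_p_dropout,
# cfg.train.freeze). The concrete values below are hypothetical, and the
# parent KVUModel may require additional config keys; the real project builds
# `batches` in its own data pipeline.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "model": {
            "backbone": "xlm-roberta",                    # hypothetical choice
            "pretrained_model_path": "xlm-roberta-base",  # hypothetical checkpoint
            "n_classes": 10,                              # hypothetical label count
            "head_hidden_size": 256,
            "head_p_dropout": 0.1,
        },
        "train": {"freeze": True},
    })
    model = DocumentKVUModel(cfg)

    # `forward` expects sliding windows plus a document-level batch; each
    # window carries token ids, boxes, and masks under the keys below.
    seq_len = 512
    window = {
        "image": torch.zeros(1, 3, 224, 224),
        "input_ids_layoutxlm": torch.zeros(1, seq_len, dtype=torch.long),
        "bbox": torch.zeros(1, seq_len, 4, dtype=torch.long),
        "attention_mask_layoutxlm": torch.ones(1, seq_len, dtype=torch.long),
    }
    head_outputs, loss = model({"windows": [window], "documents": window})
    print({k: tuple(v.shape) for k, v in head_outputs.items()})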