import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model
from sdsvkvu.model.relation_extractor import RelationExtractor
from sdsvkvu.model.kvu_model import KVUModel
# from utils import load_checkpoint


class SBTModel(KVUModel):
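    """SBT variant of KVUModel that works on sliding windows of a document.

    A backbone selected via ``cfg.model.backbone`` (e.g. ``'layoutxlm'``) encodes
    each window; on top of it sit three heads: initial token classification
    (``itc_layer``), subsequent token classification (``stc_layer``) and token
    linking (``relation_layer``). A mirrored set of ``*_document`` heads is
    applied to the concatenated window representations in ``forward``.
    """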
    def __init__(self, cfg):
        super().__init__(cfg=cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train
        self.n_classes = len(self.model_cfg.class_names)

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
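        # Build the window-level heads (itc/stc/relation) plus a mirrored set of
        # document-level heads applied to the concatenated window representations
        # in ``forward``.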
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # self.n_classes = self.model_cfg.n_classes + 1
        # self.relations = self.model_cfg.n_relations
        self.repr_hiddent_size = self.backbone_hidden_size

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # Classification layers for the whole document
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.n_classes),
        )

        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.relation_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
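        # Initialise all head weights; ``_init_weight`` is provided by the parent
        # KVUModel (it is not defined in this module).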
        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)

        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)

    def forward(self, batches):
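        # Expected structure of ``batches`` (inferred from the accesses below):
        #   batches["windows"]  : list of dicts with "image", "bbox",
        #                         "input_ids_layoutxlm", "attention_mask_layoutxlm"
        #                         and, during training, keys containing "labels"
        #   batches["documents"]: dict holding the document-level label keys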
        head_outputs_list = []
        loss = 0.
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

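            # Optionally freeze the backbone so that only the heads are trained.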
            if self.freeze:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

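            # Keep at most the first 512 token embeddings and swap to
            # (seq_len, batch, hidden) layout before applying the heads.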
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)

            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {"window_repr": window_repr,
                            "itc_outputs": itc_outputs,
                            "stc_outputs": stc_outputs,
                            "el_outputs": el_outputs}

            if any(['labels' in key for key in batch.keys()]):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

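        # Document-level pass: concatenate every window representation along the
        # sequence dimension and rerun the three heads over the whole document.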
        batch = batches["documents"]

        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)

        head_outputs = {"itc_outputs": itc_outputs,
                        "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs}

        if any(['labels' in key for key in batch.keys()]):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
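        # Total loss is the sum of the three head losses; the individual helpers
        # (_get_itc_loss, _get_stc_loss, _get_el_loss) come from the parent KVUModel.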
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch, from_key=False)

        loss = itc_loss + stc_loss + el_loss

        return loss