import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model

from sdsvkvu.model.relation_extractor import RelationExtractor
from sdsvkvu.model.kvu_model import KVUModel
# from utils import load_checkpoint


class SBTModel(KVUModel):
    """KVU model with per-window heads plus document-level heads.

    Every window is encoded by the backbone and scored by three heads
    (``itc``, ``stc`` and ``relation``).  The window representations are
    then concatenated and scored by a second, separate set of three heads
    whose outputs are what ``forward`` returns.
    """

    def __init__(self, cfg):
        """Build the backbone and all classification / relation heads.

        Args:
            cfg: Configuration object.  This class reads
                ``cfg.model.class_names``, ``cfg.model.backbone``,
                ``cfg.model.head_hidden_size``, ``cfg.model.head_p_dropout``
                and ``cfg.train.freeze``.
        """
        super().__init__(cfg=cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train
        self.n_classes = len(self.model_cfg.class_names)

        # NOTE(review): `_get_backbones` is inherited; it is presumably what
        # sets `self.backbone` and `self.backbone_config` used below - confirm.
        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        """Create and initialise the window-level and document-level heads.

        Modules are created and initialised in a fixed order (window heads
        first, then document heads) so parameter registration order and
        random-number consumption during initialisation stay stable.
        """
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # Attribute name (including its spelling) is kept as-is because code
        # outside this class may read it.
        self.repr_hiddent_size = self.backbone_hidden_size

        def make_token_classifier(hidden_size):
            # Dropout -> Linear -> Dropout -> Linear projecting to n_classes.
            return nn.Sequential(
                nn.Dropout(self.head_p_dropout),
                nn.Linear(hidden_size, hidden_size),
                nn.Dropout(self.head_p_dropout),
                nn.Linear(hidden_size, self.n_classes),
            )

        def make_relation_head(hidden_size):
            # Single-relation extractor over pairs of token representations.
            return RelationExtractor(
                n_relations=1,
                backbone_hidden_size=hidden_size,
                head_hidden_size=self.head_hidden_size,
                head_p_dropout=self.head_p_dropout,
            )

        # (1) Initial token classification
        self.itc_layer = make_token_classifier(self.backbone_hidden_size)
        # (2) Subsequent token classification
        self.stc_layer = make_relation_head(self.backbone_hidden_size)
        # (3) Linking token classification
        self.relation_layer = make_relation_head(self.backbone_hidden_size)

        # Classification layers for the whole document
        self.itc_layer_document = make_token_classifier(self.repr_hiddent_size)
        self.stc_layer_document = make_relation_head(self.repr_hiddent_size)
        self.relation_layer_document = make_relation_head(self.repr_hiddent_size)

        for head in (
            self.itc_layer,
            self.stc_layer,
            self.relation_layer,
            self.itc_layer_document,
            self.stc_layer_document,
            self.relation_layer_document,
        ):
            head.apply(self._init_weight)

    def forward(self, batches):
        """Run the window-level pass, then the document-level pass.

        Args:
            batches: Mapping with two entries:
                ``"windows"`` - iterable of per-window mappings, each
                providing ``"image"``, ``"input_ids_layoutxlm"``, ``"bbox"``
                and ``"attention_mask_layoutxlm"``;
                ``"documents"`` - mapping used only for the document-level
                loss (label lookup).

        Returns:
            Tuple ``(head_outputs, loss)``.  ``head_outputs`` is the
            document-level dict with keys ``"itc_outputs"``,
            ``"stc_outputs"`` and ``"el_outputs"``.  ``loss`` is the sum of
            every window loss and the document loss; it stays the float
            ``0.`` when no batch has a key containing ``"labels"``.
        """
        head_outputs_list = []
        loss = 0.

        # Freezing does not depend on the window, so do it once per call
        # instead of re-walking every backbone parameter for each window.
        if self.freeze:
            for param in self.backbone.parameters():
                param.requires_grad = False

        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

            if self.model_cfg.backbone == "layoutxlm":
                backbone_outputs = self.backbone(
                    image=image,
                    input_ids=input_ids,
                    bbox=bbox,
                    attention_mask=attention_mask,
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep at most the first 512 positions along dim 1.
            # NOTE(review): presumably this drops the visual tokens that a
            # LayoutLMv2-style backbone appends after the text tokens - confirm.
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            # The heads are fed the tensor with its first two axes swapped
            # (presumably sequence-first - confirm against RelationExtractor).
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            # squeeze(0) removes the leading axis when it has size 1
            # (presumably the relation axis, since n_relations=1 - confirm).
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)

            # Swap the axes back so windows can be concatenated along dim 1.
            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {
                "window_repr": window_repr,
                "itc_outputs": itc_outputs,
                "stc_outputs": stc_outputs,
                "el_outputs": el_outputs,
            }

            if any("labels" in key for key in batch.keys()):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        batch = batches["documents"]

        document_repr = torch.cat([w["window_repr"] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
        }

        if any("labels" in key for key in batch.keys()):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        """Return the sum of the ITC, STC and entity-linking losses.

        Args:
            head_outputs: Dict with ``"itc_outputs"``, ``"stc_outputs"`` and
                ``"el_outputs"``.
            batch: Mapping forwarded unchanged to the inherited
                ``_get_itc_loss`` / ``_get_stc_loss`` / ``_get_el_loss``
                helpers.

        Returns:
            ``itc_loss + stc_loss + el_loss``.
        """
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch, from_key=False)

        return itc_loss + stc_loss + el_loss