import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model
from transformers import LayoutXLMTokenizer
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel

from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint


class DocumentKVUModel(KVUModel):
    """KVU model that runs per-window token/relation heads, then a second set of
    heads over the concatenated window representations for the whole document."""

    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train

        self._get_backbones(self.model_cfg.backbone)
        # if 'pth' in self.model_cfg.ckpt_model_file:
        #     self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone)
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        self.n_classes = self.model_cfg.n_classes + 1  # extra slot for the "no class" label
        self.repr_hidden_size = self.backbone_hidden_size

        # Per-window heads
        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification (from key)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # Classification heads for the whole document
        # (1) Initial token classification
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.repr_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification (from key)
        self.relation_layer_from_key_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)

        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)
        self.relation_layer_from_key_document.apply(self._init_weight)

    def _get_backbones(self, config_type):
        configs = {
            'layoutlm': {
                'config': LayoutLMConfig,
                'tokenizer': LayoutLMTokenizer,
                'backbone': LayoutLMModel,
                'feature_extractor': LayoutLMv2FeatureExtractor,
            },
            'layoutxlm': {
                'config': LayoutLMv2Config,
                'tokenizer': LayoutXLMTokenizer,
                'backbone': LayoutLMv2Model,
                'feature_extractor': LayoutLMv2FeatureExtractor,
            },
            'xlm-roberta': {
                'config': XLMRobertaConfig,
                'tokenizer': AutoTokenizer,
                'backbone': XLMRobertaModel,
                'feature_extractor': LayoutLMv2FeatureExtractor,
            },
        }

        self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
        if config_type != 'xlm-roberta':
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
        else:
            # AutoTokenizer must fall back to the slow tokenizer here
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
        self.feature_extractor = configs[config_type]['feature_extractor'](apply_ocr=False)
        self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)

    def forward(self, batches):
        head_outputs_list = []
        loss = 0.0

        # Per-window pass: encode each sliding window and run the window-level heads
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

            if self.freeze:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep only the text tokens (first 512 positions) and switch to
            # (seq_len, batch, hidden) layout for the relation heads
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {
                "window_repr": window_repr,
                "itc_outputs": itc_outputs,
                "stc_outputs": stc_outputs,
                "el_outputs": el_outputs,
                "el_outputs_from_key": el_outputs_from_key,
            }

            if any('labels' in key for key in batch):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        # Document-level pass: concatenate window representations along the
        # sequence axis and run the document-level heads
        batch = batches["documents"]

        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
        }

        if any('labels' in key for key in batch):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss
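

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test for constructing the model, assuming the repo's configs
# are omegaconf-style namespaces exposing the attributes read in __init__ and
# _create_head above. The checkpoint path and head sizes below are placeholder
# assumptions, not values from this repository.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "model": {
            "backbone": "layoutxlm",
            "pretrained_model_path": "microsoft/layoutxlm-base",  # placeholder checkpoint
            "head_hidden_size": 128,
            "head_p_dropout": 0.1,
            "n_classes": 9,  # heads emit n_classes + 1 logits (extra "no class" slot)
        },
        "train": {"freeze": True},
    })
    model = DocumentKVUModel(cfg)
    print(model.backbone_config.hidden_size)  # e.g. 768 for a base-size backbone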