import os

import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from transformers import AutoTokenizer, XLMRobertaModel

from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint


class CombinedKVUModel(KVUModel):
    """KVU model: a LayoutXLM backbone shared by four prediction heads.

    The heads are ``itc_layer``, ``stc_layer``, ``relation_layer`` and
    ``relation_layer_from_key``. The backbone and heads are expected to be
    built by ``_get_backbones`` / ``_create_head``, which are not defined in
    this file (presumably inherited from ``KVUModel`` -- confirm).

    Config fields read:
        cfg.model                      -> stored as ``self.model_cfg``
        cfg.model.backbone             -> passed to ``_get_backbones``
        cfg.model.ckpt_model_file      -> optional checkpoint to restore modules from
        cfg.train.freeze               -> if truthy, freeze every "backbone" parameter
        cfg.train.finetune_only        -> 'EE' | 'EL' | 'ELK' restricts training to
                                          one group of heads; any other value
                                          applies no head-level restriction
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        ckpt_file = self.model_cfg.ckpt_model_file
        # Guard against None/empty: os.path.exists(None) raises TypeError
        # instead of returning False.
        if ckpt_file and os.path.exists(ckpt_file):
            self._load_checkpoint_modules(ckpt_file)

        self.loss_func = nn.CrossEntropyLoss()

        self._freeze_parameters()

    def _load_checkpoint_modules(self, ckpt_file):
        """Restore the backbone and every head from ``ckpt_file``.

        Each sub-module is passed to ``load_checkpoint`` together with its
        attribute name. The returned module replaces the attribute.
        """
        for module_name in (
            'backbone_layoutxlm',
            'itc_layer',
            'stc_layer',
            'relation_layer',
            'relation_layer_from_key',
        ):
            module = getattr(self, module_name)
            setattr(self, module_name,
                    load_checkpoint(ckpt_file, module, module_name))

    def _freeze_parameters(self):
        """Set ``requires_grad = False`` on parameters that must not be trained.

        This never re-enables gradients; it only ever switches them off.
        - ``self.freeze`` freezes every parameter whose name contains 'backbone'.
        - ``self.finetune_only`` keeps trainable only the matching head group:
            'EE'  -> parameters of itc_layer / stc_layer
            'EL'  -> relation_layer, excluding relation_layer_from_key
            'ELK' -> relation_layer_from_key only
        """
        # Predicates say which parameter names stay trainable in each mode.
        # Note that 'relation_layer' is a substring of 'relation_layer_from_key',
        # hence the explicit exclusion for 'EL'.
        keep_trainable_by_mode = {
            'EE': lambda n: 'itc_layer' in n or 'stc_layer' in n,
            'EL': lambda n: ('relation_layer' in n
                             and 'relation_layer_from_key' not in n),
            'ELK': lambda n: 'relation_layer_from_key' in n,
        }
        keep_trainable = keep_trainable_by_mode.get(self.finetune_only)

        for name, param in self.named_parameters():
            if self.freeze and 'backbone' in name:
                param.requires_grad = False
            if keep_trainable is not None and not keep_trainable(name):
                param.requires_grad = False

    def forward(self, batch):
        """Run the backbone and all four heads on one batch.

        Args:
            batch: mapping that must contain "image", "input_ids_layoutxlm",
                "bbox" and "attention_mask_layoutxlm". If any key contains the
                substring "labels", the loss is computed with ``self._get_loss``.

        Returns:
            tuple ``(head_outputs, loss)``. ``head_outputs`` is a dict with keys
            "itc_outputs", "stc_outputs", "el_outputs" and "el_outputs_from_key".
            ``loss`` is 0.0 when the batch carries no label keys.
        """
        input_ids_layoutxlm = batch["input_ids_layoutxlm"]

        backbone_outputs_layoutxlm = self.backbone_layoutxlm(
            image=batch["image"],
            input_ids=input_ids_layoutxlm,
            bbox=batch["bbox"],
            attention_mask=batch["attention_mask_layoutxlm"],
        )

        # Assuming an HF LayoutLMv2-style backbone, last_hidden_state holds the
        # text tokens followed by the visual tokens. Keep only the text
        # positions by slicing to the actual text length rather than a fixed
        # 512, which would include visual tokens for shorter inputs.
        # (Identical to the old [:, :512, :] when inputs are 512 long.)
        text_len = input_ids_layoutxlm.size(1)
        last_hidden_states = (
            backbone_outputs_layoutxlm.last_hidden_state[:, :text_len, :]
        )
        # Swap dims 0 and 1 before the heads. Presumably (batch, seq, hidden)
        # -> (seq, batch, hidden) as the heads expect; confirm against the
        # head definitions.
        last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(
            last_hidden_states, last_hidden_states
        ).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
        }

        loss = 0.0
        if any('labels' in key for key in batch):
            loss = self._get_loss(head_outputs, batch)

        return head_outputs, loss