import os

import torch
from torch import nn
from transformers import (AutoTokenizer, LayoutLMv2FeatureExtractor,
                          LayoutLMv2Model, LayoutXLMTokenizer, XLMRobertaModel)

from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint

class CombinedKVUModel(KVUModel):
    """KVU model combining entity-extraction (EE) and entity-linking
    (EL/ELK) heads on top of a shared LayoutXLM backbone."""

    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        # Restore the backbone and every task head from the checkpoint,
        # if one exists.
        if os.path.exists(self.model_cfg.ckpt_model_file):
            self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
            self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
            self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
            self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
            self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')

        self.loss_func = nn.CrossEntropyLoss()

        # Optionally freeze the backbone, and/or restrict training to one
        # task head: 'EE' keeps only the token-classification layers
        # (itc/stc), 'EL' only the key-value relation layer, and 'ELK' only
        # the relation-from-key layer.
        if self.freeze:
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False
        if self.finetune_only == 'EE':
            for name, param in self.named_parameters():
                if 'itc_layer' not in name and 'stc_layer' not in name:
                    param.requires_grad = False
        if self.finetune_only == 'EL':
            for name, param in self.named_parameters():
                if 'relation_layer' not in name or 'relation_layer_from_key' in name:
                    param.requires_grad = False
        if self.finetune_only == 'ELK':
            for name, param in self.named_parameters():
                if 'relation_layer_from_key' not in name:
                    param.requires_grad = False
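
    # Illustrative helper, not in the original file: a minimal sketch for
    # checking what the freeze/finetune_only settings above actually leave
    # trainable. It relies only on standard nn.Module APIs.
    def trainable_parameter_names(self):
        """Names of the parameters that will still receive gradients."""
        return [name for name, param in self.named_parameters() if param.requires_grad]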

    def forward(self, batch):
        image = batch["image"]
        input_ids_layoutxlm = batch["input_ids_layoutxlm"]
        bbox = batch["bbox"]
        attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

        backbone_outputs_layoutxlm = self.backbone_layoutxlm(
            image=image, input_ids=input_ids_layoutxlm, bbox=bbox,
            attention_mask=attention_mask_layoutxlm)

        # Keep only the first 512 positions (the text tokens; LayoutLMv2-style
        # backbones append visual-patch tokens after them) and move to
        # (seq_len, batch, hidden) for the heads.
        last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
        last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs,
                        "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs,
                        "el_outputs_from_key": el_outputs_from_key}

        # Loss is computed only when label tensors are present in the batch
        # (training/validation); at inference it stays 0.0.
        loss = 0.0
        if any('labels' in key for key in batch):
            loss = self._get_loss(head_outputs, batch)

        return head_outputs, loss
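

# A minimal smoke-test sketch, assuming an OmegaConf/Hydra-style config
# exposing the fields read above (model.backbone, model.ckpt_model_file,
# train.freeze, train.finetune_only) and a batch shaped like the LayoutXLM
# inputs consumed in forward(). The field values below are hypothetical
# placeholders, not the project's real defaults.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        model=SimpleNamespace(
            backbone='layoutxlm',                    # assumed backbone key
            ckpt_model_file='checkpoints/kvu.pth'),  # hypothetical path
        train=SimpleNamespace(freeze=True, finetune_only=''))
    model = CombinedKVUModel(cfg).eval()

    batch_size, seq_len = 1, 512
    batch = {
        "image": torch.zeros(batch_size, 3, 224, 224),  # page image tensor
        "input_ids_layoutxlm": torch.zeros(batch_size, seq_len, dtype=torch.long),
        "bbox": torch.zeros(batch_size, seq_len, 4, dtype=torch.long),  # 0-1000 boxes
        "attention_mask_layoutxlm": torch.ones(batch_size, seq_len, dtype=torch.long),
    }
    with torch.no_grad():
        head_outputs, loss = model(batch)  # loss stays 0.0: no '*labels*' keys
    print({k: tuple(v.shape) for k, v in head_outputs.items()})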