sbt-idp/cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py


import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from transformers import AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint


class CombinedKVUModel(KVUModel):
    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        # Restore backbone and head weights from a previous checkpoint, if one exists.
        if os.path.exists(self.model_cfg.ckpt_model_file):
            self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
            self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
            self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
            self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
            self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')

        self.loss_func = nn.CrossEntropyLoss()

        # Optionally freeze the backbone, and/or restrict training to a single task head:
        #   EE  -> entity extraction heads (itc_layer, stc_layer)
        #   EL  -> entity linking head (relation_layer only, not relation_layer_from_key)
        #   ELK -> entity linking from key head (relation_layer_from_key only)
        if self.freeze:
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False
        if self.finetune_only == 'EE':
            for name, param in self.named_parameters():
                if 'itc_layer' not in name and 'stc_layer' not in name:
                    param.requires_grad = False
        if self.finetune_only == 'EL':
            for name, param in self.named_parameters():
                if 'relation_layer' not in name or 'relation_layer_from_key' in name:
                    param.requires_grad = False
        if self.finetune_only == 'ELK':
            for name, param in self.named_parameters():
                if 'relation_layer_from_key' not in name:
                    param.requires_grad = False
    def forward(self, batch):
        image = batch["image"]
        input_ids_layoutxlm = batch["input_ids_layoutxlm"]
        bbox = batch["bbox"]
        attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

        backbone_outputs_layoutxlm = self.backbone_layoutxlm(
            image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)

        # Keep the first 512 positions (text tokens) and switch to (seq_len, batch, hidden).
        last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
        last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

        head_outputs = {"itc_outputs": itc_outputs,
                        "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs,
                        "el_outputs_from_key": el_outputs_from_key}

        # Compute the loss only when the batch carries label tensors.
        loss = 0.0
        if any('labels' in key for key in batch):
            loss = self._get_loss(head_outputs, batch)

        return head_outputs, loss
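

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline).
# It shows the batch schema that forward() reads; the base KVUModel is
# expected to build the LayoutXLM backbone and the ITC/STC/relation heads
# from cfg.model.backbone. Tensor shapes below follow common
# LayoutLMv2/LayoutXLM conventions (batch of 1, 512 tokens, 224x224 image)
# and are assumptions, not values taken from this repository's configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Dummy inputs matching the keys read in CombinedKVUModel.forward().
    dummy_batch = {
        "image": torch.zeros(1, 3, 224, 224),
        "input_ids_layoutxlm": torch.zeros(1, 512, dtype=torch.long),
        "bbox": torch.zeros(1, 512, 4, dtype=torch.long),   # word boxes on a 0-1000 grid
        "attention_mask_layoutxlm": torch.ones(1, 512, dtype=torch.long),
    }

    # With a full experiment config (see the cfg.model.* / cfg.train.* accesses
    # above) the model would be exercised roughly as:
    #   model = CombinedKVUModel(cfg).eval()
    #   with torch.no_grad():
    #       head_outputs, loss = model(dummy_batch)
    #   # loss stays 0.0 here because no "*labels*" key is present in the batch.
    print({k: tuple(v.shape) for k, v in dummy_batch.items()})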