import os

import torch
from torch import nn
from transformers import (
    AutoTokenizer,
    LayoutLMConfig,
    LayoutLMModel,
    LayoutLMTokenizer,
    LayoutLMv2Config,
    LayoutLMv2FeatureExtractor,
    LayoutLMv2Model,
    LayoutXLMTokenizer,
    XLMRobertaConfig,
    XLMRobertaModel,
)

# from model import load_checkpoint
from sdsvkvu.sources.utils import merged_token_embeddings
from sdsvkvu.model.relation_extractor import RelationExtractor


class KVUModel(nn.Module):
    """KVU model: a document-layout backbone (LayoutXLM by default) with
    token-classification and relation-extraction heads."""

    def __init__(self, cfg):
        super().__init__()

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only
        self.n_classes = len(self.model_cfg.class_names)

        self._get_backbones(self.model_cfg.backbone)

        # if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
        #     self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')

        self._create_head()
        self.loss_func = nn.CrossEntropyLoss()

        if self.freeze:
            # Train only the heads: freeze every backbone parameter.
            for name, param in self.named_parameters():
                if "backbone" in name:
                    param.requires_grad = False

    def _create_head(self):
        self.backbone_hidden_size = 768
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # self.n_classes = self.model_cfg.n_classes + 1

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )

        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification (from key tokens)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)

    # def _get_backbones(self, config_type):
    #     self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
    #     self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    #     self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')

    def _get_backbones(self, config_type):
        """Instantiate the config, tokenizer, feature extractor and backbone for
        the selected architecture ("layoutlm", "layoutxlm" or "xlm-roberta")."""
        configs = {
            "layoutlm": {
                "config": LayoutLMConfig,
                "tokenizer": LayoutLMTokenizer,
                "backbone": LayoutLMModel,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
            "layoutxlm": {
                "config": LayoutLMv2Config,
                "tokenizer": LayoutXLMTokenizer,
                "backbone": LayoutLMv2Model,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
            "xlm-roberta": {
                "config": XLMRobertaConfig,
                "tokenizer": AutoTokenizer,
                "backbone": XLMRobertaModel,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
        }

        self.backbone_config = configs[config_type]["config"].from_pretrained(
            self.model_cfg.pretrained_model_path
        )
        if config_type != "xlm-roberta":
            self.tokenizer = configs[config_type]["tokenizer"].from_pretrained(
                self.model_cfg.pretrained_model_path
            )
        else:
            self.tokenizer = configs[config_type]["tokenizer"].from_pretrained(
                self.model_cfg.pretrained_model_path, use_fast=False
            )
        self.feature_extractor = configs[config_type]["feature_extractor"](
            apply_ocr=False
        )
        # forward() (and the commented-out checkpoint loader) refer to this attribute name.
        self.backbone_layoutxlm = configs[config_type]["backbone"].from_pretrained(
            self.model_cfg.pretrained_model_path
        )

    @staticmethod
    def _init_weight(module):
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, lbatches):
        windows = lbatches["windows"]
        token_embeddings_windows = []
        lvalids = []
        loverlaps = []

        for batch in windows:
            batch = {
                k: v.cuda() for k, v in batch.items() if k not in ("img_path", "words")
            }
            image = batch["image"]
            input_ids_layoutxlm = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

            backbone_outputs_layoutxlm = self.backbone_layoutxlm(
                image=image,
                input_ids=input_ids_layoutxlm,
                bbox=bbox,
                attention_mask=attention_mask_layoutxlm,
            )

            # Keep only the text-token embeddings (LayoutXLM appends visual tokens after the text sequence).
            last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[
                :, :512, :
            ]

            lvalids.append(batch["len_valid_tokens"])
            loverlaps.append(batch["len_overlap_tokens"])
            token_embeddings_windows.append(last_hidden_states_layoutxlm)

        # Stitch the sliding-window embeddings back into one sequence per document.
        token_embeddings = merged_token_embeddings(
            token_embeddings_windows, loverlaps, lvalids, average=False
        )

        token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
        itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(
            token_embeddings, token_embeddings
        ).squeeze(0)
        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
            "embedding_tokens": token_embeddings.transpose(0, 1)
            .contiguous()
            .detach()
            .cpu()
            .numpy(),
        }

        loss = 0.0
        if any("labels" in key for key in lbatches):
            labels = {
                k: v.cuda()
                for k, v in lbatches["documents"].items()
                if k != "img_path"
            }
            loss = self._get_loss(head_outputs, labels)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        loss = itc_loss + stc_loss + el_loss + el_loss_from_key

        return loss

    def _get_itc_loss(self, itc_outputs, batch):
        # Compute the loss only on the first token of each box.
        itc_mask = batch["are_box_first_tokens"].view(-1)

        itc_logits = itc_outputs.view(-1, self.n_classes)  # matches the itc head output size
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        itc_loss = self.loss_func(itc_logits, itc_labels)

        return itc_loss

    def _get_stc_loss(self, stc_outputs, batch):
        inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Block padding positions as targets; the appended zero column keeps the
        # extra (dummy) target slot usable.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], dim=1
        ).bool()

        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        # Prevent a token from selecting itself as its target.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()
        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        stc_loss = self.loss_func(stc_logits, stc_labels)

        return stc_loss

    def _get_el_loss(self, el_outputs, batch, from_key=False):
        bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
        device = batch["attention_mask_layoutxlm"].device

        # Prevent a token from linking to itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Only first tokens of boxes are valid link targets; block everything else,
        # while the appended zero column keeps the extra (dummy) target slot usable.
        box_first_token_mask = torch.cat(
            [
                ~batch["are_box_first_tokens"],
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            dim=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        if from_key:
            el_labels = batch["el_labels_from_key"]
        else:
            el_labels = batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        loss = self.loss_func(logits, labels)
        return loss
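

if __name__ == "__main__":
    # Minimal construction sketch (illustrative only, not part of the original
    # training code). The config layout below is an assumption inferred from the
    # attributes this module reads (cfg.model.*, cfg.train.*); the class names and
    # head sizes are hypothetical placeholders. Building the model downloads the
    # pretrained LayoutXLM checkpoint and needs its extra dependencies (e.g.
    # detectron2), and forward() additionally expects CUDA tensors produced by the
    # project's own windowed data pipeline.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "model": {
                "backbone": "layoutxlm",
                "pretrained_model_path": "microsoft/layoutxlm-base",
                "class_names": ["others", "title", "key", "value"],  # hypothetical
                "head_hidden_size": 128,  # hypothetical
                "head_p_dropout": 0.1,  # hypothetical
            },
            "train": {"freeze": False, "finetune_only": False},
        }
    )
    model = KVUModel(cfg)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"KVUModel built with {n_trainable:,} trainable parameters")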