# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py

import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2
from model.relation_extractor import RelationExtractor
from utils import load_checkpoint


class KVUModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.device = 'cuda'
        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)
        if (cfg.stage == 2) and os.path.exists(self.model_cfg.ckpt_model_file):
            # Stage 2: warm-start the backbone from the stage-1 checkpoint
            # before the (randomly initialized) task heads are attached.
            self.backbone_layoutxlm = load_checkpoint(
                self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

        if self.freeze:
            # Freeze every backbone parameter; only the task heads stay trainable.
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False

    def _create_head(self):
        self.backbone_hidden_size = 768
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        self.n_classes = self.model_cfg.n_classes + 1  # extra class for tokens outside any field

        # (1) Initial token classification: tags the first token of each box.
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification: chains the remaining tokens of an entity.
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Entity linking: relations between entities.
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Entity linking predicted from the key side.
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)
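
    # Head output shapes, as consumed by the loss functions below (inferred
    # from the .view() calls in the _get_*_loss methods; RelationExtractor is
    # assumed to emit one (seq_len, seq_len + 1) score map per relation):
    #   itc_outputs:           (batch, seq_len, n_classes + 1)
    #   stc_outputs:           (batch, seq_len, seq_len + 1)
    #   el_outputs(_from_key): same shape as stc_outputs
    # The extra column acts as the "no predecessor / no link" target.
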
    def _get_backbones(self, config_type):
        # NOTE: config_type is currently unused; the checkpoint name is hardcoded.
        self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
        self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')
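
    # Illustrative sketch of how one window batch could be prepared with the
    # components above (the real preprocessing lives elsewhere in this repo;
    # the variable names here are assumptions):
    #
    #   encoding = self.tokenizer_layoutxlm(words, boxes=word_boxes,
    #                                       truncation=True, padding='max_length',
    #                                       max_length=512, return_tensors='pt')
    #   pixels = self.feature_extractor(pil_image, return_tensors='pt').pixel_values
    #   window = {'image': pixels,
    #             'input_ids_layoutxlm': encoding['input_ids'],
    #             'bbox': encoding['bbox'],
    #             'attention_mask_layoutxlm': encoding['attention_mask']}
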
    @staticmethod
    def _init_weight(module):
        # BERT-style initialization: N(0, 0.02) weights, zero biases.
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
    def forward(self, lbatches):
        # A long document arrives as overlapping 512-token windows; each window
        # runs through the backbone, and the window embeddings are merged into
        # one document-level sequence before the four heads are applied.
        windows = lbatches['windows']
        token_embeddings_windows = []
        lvalids = []
        loverlaps = []

        for batch in windows:
            batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')}
            image = batch["image"]
            input_ids_layoutxlm = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

            backbone_outputs_layoutxlm = self.backbone_layoutxlm(
                image=image, input_ids=input_ids_layoutxlm, bbox=bbox,
                attention_mask=attention_mask_layoutxlm)
            # Keep only the 512 text-token positions; LayoutLMv2 appends its
            # visual tokens after the text tokens.
            last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]

            lvalids.append(batch['len_valid_tokens'])
            loverlaps.append(batch['len_overlap_tokens'])
            token_embeddings_windows.append(last_hidden_states_layoutxlm)

        # Stitch the window embeddings back into a single document sequence.
        token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False)
        # token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True)

        token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
        itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key,
                        'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()}

        loss = 0.0
        if any('labels' in key for key in lbatches):
            # Note: the filter must not be the bare string ('img_path'), which
            # would turn `not in` into a substring test.
            labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k != 'img_path'}
            loss = self._get_loss(head_outputs, labels)
        return head_outputs, loss
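
    # Decoding sketch (illustrative; greedy argmax is an assumption, not
    # necessarily the decoder used elsewhere in this repo):
    #
    #   head_outputs, _ = model(lbatches)
    #   itc_pred = head_outputs['itc_outputs'].argmax(-1)  # class id per token
    #   stc_pred = head_outputs['stc_outputs'].argmax(-1)  # pointer per token
    #   el_pred  = head_outputs['el_outputs'].argmax(-1)   # pointer per token
    #   # a pointer equal to seq_len (the extra column) means "no link".
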
    def _get_loss(self, head_outputs, batch):
        # Total loss is the unweighted sum of the four head losses.
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        loss = itc_loss + stc_loss + el_loss + el_loss_from_key
        return loss
    def _get_itc_loss(self, itc_outputs, batch):
        # Score only the first token of each box; the rest are handled by STC.
        itc_mask = batch["are_box_first_tokens"].view(-1)

        itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        itc_loss = self.loss_func(itc_logits, itc_labels)
        return itc_loss
    def _get_stc_loss(self, stc_outputs, batch):
        inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Mask out padding columns; the appended zeros column keeps the extra
        # "no link" target always valid.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], dim=1
        ).bool()
        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        # A token must not point to itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()

        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        stc_loss = self.loss_func(stc_logits, stc_labels)
        return stc_loss
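
    # Worked shape example (hypothetical numbers): with bsz=2 and
    # max_seq_length=512, stc_outputs is (2, 512, 513). Flattening yields
    # 1024 rows of 513 logits; only rows where attention_mask == 1 are kept,
    # and each stc_label is a column index in [0, 512], with 512 presumably
    # encoding "no predecessor".
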
    def _get_el_loss(self, el_outputs, batch, from_key=False):
        bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
        device = batch["attention_mask_layoutxlm"].device

        # A token must not link to itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Only first tokens of boxes are valid link targets; the appended
        # zeros column keeps the extra "no link" target always valid.
        box_first_token_mask = torch.cat(
            [
                ~batch["are_box_first_tokens"],
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            dim=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        el_labels = batch["el_labels_from_key"] if from_key else batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        loss = self.loss_func(logits, labels)
        return loss
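

if __name__ == '__main__':
    # Minimal construction sketch. The config layout below is an assumption
    # inferred from the attributes this module reads (cfg.stage, cfg.train.*,
    # cfg.model.*); the project's real configs live elsewhere in the repo.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'stage': 1,
        'train': {'freeze': False, 'finetune_only': False},
        'model': {
            'backbone': 'layoutxlm',
            'head_hidden_size': 128,
            'head_p_dropout': 0.1,
            'n_classes': 20,
            'ckpt_model_file': '',
        },
    })
    model = KVUModel(cfg)  # downloads microsoft/layoutxlm-base on first run
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'KVUModel created with {n_trainable:,} trainable parameters')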