import os
import torch
from torch import nn

from transformers import (
    LayoutLMConfig,
    LayoutLMModel,
    LayoutLMTokenizer,
)
from transformers import (
    LayoutLMv2Config,
    LayoutLMv2Model,
    LayoutLMv2FeatureExtractor,
    LayoutXLMTokenizer,
)
from transformers import (
    XLMRobertaConfig,
    AutoTokenizer,
    XLMRobertaModel,
)

# from model import load_checkpoint
from sdsvkvu.sources.utils import merged_token_embeddings
from sdsvkvu.model.relation_extractor import RelationExtractor


class KVUModel(nn.Module):
    """Key-value understanding model: a document backbone (LayoutLM /
    LayoutXLM / XLM-RoBERTa) followed by token-classification (ITC, STC)
    and entity-linking heads."""

    def __init__(self, cfg):
        super().__init__()

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only
        self.n_classes = len(self.model_cfg.class_names)

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        # if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
        #     self.backbone_layoutxlm = load_checkpoint(
        #         self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, "backbone_layoutxlm"
        #     )

        self.loss_func = nn.CrossEntropyLoss()

        if self.freeze:
            for name, param in self.named_parameters():
                if "backbone" in name:
                    param.requires_grad = False

    def _create_head(self):
        self.backbone_hidden_size = 768
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Entity linking
        self.relation_layer = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Entity linking from key
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)

    def _get_backbones(self, config_type):
        configs = {
            "layoutlm": {
                "config": LayoutLMConfig,
                "tokenizer": LayoutLMTokenizer,
                "backbone": LayoutLMModel,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
            "layoutxlm": {
                "config": LayoutLMv2Config,
                "tokenizer": LayoutXLMTokenizer,
                "backbone": LayoutLMv2Model,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
            "xlm-roberta": {
                "config": XLMRobertaConfig,
                "tokenizer": AutoTokenizer,
                "backbone": XLMRobertaModel,
                "feature_extractor": LayoutLMv2FeatureExtractor,
            },
        }

        self.backbone_config = configs[config_type]["config"].from_pretrained(
            self.model_cfg.pretrained_model_path
        )
        if config_type != "xlm-roberta":
            self.tokenizer = configs[config_type]["tokenizer"].from_pretrained(
                self.model_cfg.pretrained_model_path
            )
        else:
            self.tokenizer = configs[config_type]["tokenizer"].from_pretrained(
                self.model_cfg.pretrained_model_path, use_fast=False
            )
        self.feature_extractor = configs[config_type]["feature_extractor"](
            apply_ocr=False
        )
        self.backbone = configs[config_type]["backbone"].from_pretrained(
            self.model_cfg.pretrained_model_path
        )

    @staticmethod
    def _init_weight(module):
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, lbatches):
        # The forward pass assumes the LayoutXLM/LayoutLMv2 backbone: it feeds the
        # page image together with token ids and bounding boxes for each window.
        windows = lbatches["windows"]
        token_embeddings_windows = []
        lvalids = []
        loverlaps = []

        for batch in windows:
            batch = {
                k: v.cuda()
                for k, v in batch.items()
                if k not in ("img_path", "words")
            }
            image = batch["image"]
            input_ids_layoutxlm = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

            backbone_outputs = self.backbone(
                image=image,
                input_ids=input_ids_layoutxlm,
                bbox=bbox,
                attention_mask=attention_mask_layoutxlm,
            )
            # Keep only the text-token positions; the visual tokens appended by
            # LayoutLMv2/LayoutXLM are dropped.
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]

            lvalids.append(batch["len_valid_tokens"])
            loverlaps.append(batch["len_overlap_tokens"])
            token_embeddings_windows.append(last_hidden_states)

        # Stitch the sliding windows back into one token-embedding sequence.
        token_embeddings = merged_token_embeddings(
            token_embeddings_windows, loverlaps, lvalids, average=False
        )
        token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()

        itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(
            token_embeddings, token_embeddings
        ).squeeze(0)

        head_outputs = {
            "itc_outputs": itc_outputs,
            "stc_outputs": stc_outputs,
            "el_outputs": el_outputs,
            "el_outputs_from_key": el_outputs_from_key,
            "embedding_tokens": token_embeddings.transpose(0, 1)
            .contiguous()
            .detach()
            .cpu()
            .numpy(),
        }

        loss = 0.0
        if any("labels" in key for key in lbatches):
            labels = {
                k: v.cuda()
                for k, v in lbatches["documents"].items()
                if k != "img_path"
            }
            loss = self._get_loss(head_outputs, labels)
        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        return itc_loss + stc_loss + el_loss + el_loss_from_key

    def _get_itc_loss(self, itc_outputs, batch):
        # Only the first token of each box contributes to the ITC loss.
        itc_mask = batch["are_box_first_tokens"].view(-1)

        # The logits dimension must match the ITC head output size (self.n_classes).
        itc_logits = itc_outputs.view(-1, self.n_classes)
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        return self.loss_func(itc_logits, itc_labels)
    def _get_stc_loss(self, stc_outputs, batch):
        inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Mask out padded positions and self-links before computing the loss.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], dim=1
        ).bool()
        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()

        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        return self.loss_func(stc_logits, stc_labels)

    def _get_el_loss(self, el_outputs, batch, from_key=False):
        bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
        device = batch["attention_mask_layoutxlm"].device

        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Only box-first tokens can take part in a link; mask the remaining
        # positions together with self-links.
        box_first_token_mask = torch.cat(
            [
                ~batch["are_box_first_tokens"],
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            dim=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        el_labels = batch["el_labels_from_key"] if from_key else batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        return self.loss_func(logits, labels)
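

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline): shows
# how KVUModel might be instantiated. The config schema is an assumption
# reconstructed from the attributes read in __init__ (cfg.model.*, cfg.train.*),
# assuming an OmegaConf-style config; the class names and head sizes are
# placeholder values. Building the model downloads pretrained weights from the
# Hugging Face hub and, for the LayoutXLM backbone, needs the LayoutLMv2
# extras (detectron2, sentencepiece).
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "model": {
                "backbone": "layoutxlm",
                "pretrained_model_path": "microsoft/layoutxlm-base",
                "class_names": ["others", "title", "key", "value"],  # placeholder
                "head_hidden_size": 256,  # placeholder
                "head_p_dropout": 0.1,  # placeholder
            },
            "train": {"freeze": False, "finetune_only": False},
        }
    )
    model = KVUModel(cfg)
    print(model.n_classes, sum(p.numel() for p in model.parameters()))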