from torch import nn
import torch
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
    LayoutLMv2ForTokenClassification,
)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding computed on the fly in forward()."""

    def __init__(self, num_hiddens, max_len=10000):
        super(PositionalEncoding, self).__init__()
        # `P` is built per call in forward(), so only the hidden size is stored;
        # `max_len` is kept for interface compatibility but is not used here.
        self.num_hiddens = num_hiddens

    def forward(self, inputs):
        # `inputs` is the position_ids tensor of shape (1, seq_len); the returned
        # encoding broadcasts over the batch dimension when added to embeddings.
        max_len = inputs.shape[1]
        P = torch.zeros((1, max_len, self.num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000,
            torch.arange(0, self.num_hiddens, 2, dtype=torch.float32)
            / self.num_hiddens,
        )
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        return P.to(inputs.device)


def load_layoutlmv2_custom_model(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Load LayoutXLM/LayoutLMv2 and patch it to accept sequences longer than the
    pretrained maximum."""
    model, processor = load_layoutlmv2(weight_dir, tokenizer_dir, max_seq_len, classes)
    # Fix for longer sequence lengths: replace the learned position embeddings
    # (limited to the pretrained maximum) with sinusoidal encodings and resize
    # the position_ids buffer accordingly.
    model.layoutlmv2.embeddings.position_embeddings = PositionalEncoding(
        num_hiddens=768, max_len=max_seq_len
    )
    model.layoutlmv2.embeddings.max_position_embeddings = max_seq_len
    model.config.max_position_embeddings = max_seq_len
    model.layoutlmv2.embeddings.register_buffer(
        "position_ids", torch.arange(max_seq_len).expand((1, -1))
    )
    return model, processor


def load_layoutlmv2(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Build the LayoutXLM processor and token-classification model from local dirs."""
    tokenizer = LayoutXLMTokenizer.from_pretrained(
        pretrained_model_name_or_path=tokenizer_dir, model_max_length=max_seq_len
    )
    # apply_ocr=False: words and bounding boxes are supplied by an external OCR step.
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    processor = LayoutXLMProcessor(feature_extractor, tokenizer)
    model = LayoutLMv2ForTokenClassification.from_pretrained(
        weight_dir, num_labels=len(classes)
    )
    return model, processor
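

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of loading the patched model and running one
# forward pass. The directories, label set, words, and boxes below are
# assumptions for illustration only and are not part of the module above.
# LayoutLMv2's visual backbone additionally requires detectron2 to be installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    WEIGHT_DIR = "./weights/layoutxlm-finetuned"     # hypothetical model directory
    TOKENIZER_DIR = "./weights/layoutxlm-tokenizer"  # hypothetical tokenizer directory
    CLASSES = ["O", "B-HEADER", "I-HEADER"]          # hypothetical label set
    MAX_SEQ_LEN = 1024

    model, processor = load_layoutlmv2_custom_model(
        WEIGHT_DIR, TOKENIZER_DIR, MAX_SEQ_LEN, CLASSES
    )
    model.eval()

    # Because apply_ocr=False, words and boxes (normalized to a 0-1000 scale)
    # must come from an external OCR step; dummy values are used here.
    image = Image.new("RGB", (1000, 1000), color="white")
    words = ["Invoice", "number", "12345"]
    boxes = [[50, 50, 200, 80], [210, 50, 330, 80], [340, 50, 440, 80]]

    encoding = processor(
        image,
        words,
        boxes=boxes,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1)  # shape: (1, MAX_SEQ_LEN)
    print(predictions.shape)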