62 lines
2.0 KiB
Python
Executable File
62 lines
2.0 KiB
Python
Executable File
from torch import nn
|
|
import torch
|
|
from transformers import (
|
|
LayoutXLMTokenizer,
|
|
LayoutLMv2FeatureExtractor,
|
|
LayoutXLMProcessor,
|
|
LayoutLMv2ForTokenClassification,
|
|
)
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding, built on the fly.

    The module holds no parameters or buffers. On every call it builds a
    table of shape ``(1, seq_len, num_hiddens)`` where, for position ``p``
    and channel pair ``i``::

        table[0, p, 2 * i]     = sin(p / 10000 ** (2 * i / num_hiddens))
        table[0, p, 2 * i + 1] = cos(p / 10000 ** (2 * i / num_hiddens))

    Args:
        num_hiddens: Size of the last (channel) dimension of the table.
            Both even and odd sizes are supported.
        max_len: Accepted for interface compatibility only. The table is
            sized from the input at call time, so this value is not used.
    """

    def __init__(self, num_hiddens, max_len=10000):
        super().__init__()
        self.num_hiddens = num_hiddens

    def forward(self, inputs):
        """Return the encoding table for ``inputs.shape[1]`` positions.

        Only the second dimension and the device of ``inputs`` are read;
        its values are ignored, so positions are always ``0..seq_len - 1``.

        Args:
            inputs: Tensor with at least two dimensions; dimension 1 is
                taken as the sequence length.

        Returns:
            A ``float32`` tensor of shape ``(1, seq_len, num_hiddens)`` on
            the same device as ``inputs``. The leading dimension of 1 lets
            it broadcast over a batch.
        """
        seq_len = inputs.shape[1]
        device = inputs.device

        # Build everything directly on the target device to avoid a
        # CPU allocation followed by a host-to-device copy on every call.
        positions = torch.arange(
            seq_len, dtype=torch.float32, device=device
        ).reshape(-1, 1)
        exponents = (
            torch.arange(
                0, self.num_hiddens, 2, dtype=torch.float32, device=device
            )
            / self.num_hiddens
        )
        angles = positions / torch.pow(10000, exponents)

        table = torch.zeros((1, seq_len, self.num_hiddens), device=device)
        table[:, :, 0::2] = torch.sin(angles)
        # With an odd `num_hiddens` there is one fewer odd-indexed channel
        # than even-indexed ones, so drop the surplus angle column instead
        # of failing on a shape mismatch.
        table[:, :, 1::2] = torch.cos(angles[:, : self.num_hiddens // 2])
        return table
|
|
|
|
|
|
def load_layoutlmv2_custom_model(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Load a LayoutLMv2 token classifier patched for longer sequences.

    The model and processor come from ``load_layoutlmv2``. The model's
    position embedding is then replaced by a ``PositionalEncoding`` module,
    and the position-related limits are raised to ``max_seq_len``.

    Args:
        weight_dir: Location of the model weights, forwarded to
            ``load_layoutlmv2``.
        tokenizer_dir: Location of the tokenizer files, forwarded to
            ``load_layoutlmv2``.
        max_seq_len: Sequence length the patched model should accept.
        classes: Label collection; forwarded to ``load_layoutlmv2``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # Pass by keyword: `load_layoutlmv2` takes `weight_dir` first, so a
    # positional call in the wrong order silently swaps the two paths.
    model, processor = load_layoutlmv2(
        weight_dir=weight_dir,
        tokenizer_dir=tokenizer_dir,
        max_seq_len=max_seq_len,
        classes=classes,
    )

    # Fix for longer length: swap the position embedding for
    # PositionalEncoding, which sizes its table from the input at call time.
    # The width comes from the loaded config so it always matches the
    # checkpoint (768 for base-sized models).
    embeddings = model.layoutlmv2.embeddings
    embeddings.position_embeddings = PositionalEncoding(
        num_hiddens=model.config.hidden_size, max_len=max_seq_len
    )
    embeddings.max_position_embeddings = max_seq_len
    model.config.max_position_embeddings = max_seq_len
    # Re-register `position_ids` so that it spans the new maximum length.
    embeddings.register_buffer(
        "position_ids", torch.arange(max_seq_len).expand((1, -1))
    )

    return model, processor
|
|
|
|
|
|
def load_layoutlmv2(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    """Create a LayoutLMv2 token-classification model and its processor.

    Args:
        weight_dir: Passed to
            ``LayoutLMv2ForTokenClassification.from_pretrained``.
        tokenizer_dir: Passed to ``LayoutXLMTokenizer.from_pretrained``.
        max_seq_len: Used as the tokenizer's ``model_max_length``.
        classes: Label collection; its length becomes ``num_labels``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    xlm_tokenizer = LayoutXLMTokenizer.from_pretrained(
        pretrained_model_name_or_path=tokenizer_dir,
        model_max_length=max_seq_len,
    )
    # The feature extractor is created with OCR switched off.
    processor = LayoutXLMProcessor(
        LayoutLMv2FeatureExtractor(apply_ocr=False), xlm_tokenizer
    )

    label_count = len(classes)
    token_classifier = LayoutLMv2ForTokenClassification.from_pretrained(
        weight_dir, num_labels=label_count
    )
    return token_classifier, processor
|