sbt-idp/cope2n-ai-fi/common/utils_invoice/load_model.py
from torch import nn
import torch
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
    LayoutLMv2ForTokenClassification,
)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, computed on the fly for the input length.

    P[0, pos, 2j]   = sin(pos / 10000^(2j / num_hiddens))
    P[0, pos, 2j+1] = cos(pos / 10000^(2j / num_hiddens))
    """

    def __init__(self, num_hiddens, max_len=10000):
        super(PositionalEncoding, self).__init__()
        # `max_len` is kept for interface compatibility; the actual length is
        # taken from the input sequence at forward time.
        self.num_hiddens = num_hiddens

    def forward(self, inputs):
        max_len = inputs.shape[1]
        P = torch.zeros((1, max_len, self.num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000,
            torch.arange(0, self.num_hiddens, 2, dtype=torch.float32)
            / self.num_hiddens,
        )
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        return P.to(inputs.device)


def load_layoutlmv2_custom_model(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    # Pass arguments in the order declared by load_layoutlmv2
    # (weight_dir first, then tokenizer_dir).
    model, processor = load_layoutlmv2(weight_dir, tokenizer_dir, max_seq_len, classes)
    # Fix for longer lengths: replace the learned position embeddings with
    # sinusoidal ones so sequences longer than the pretrained maximum are supported.
    model.layoutlmv2.embeddings.position_embeddings = PositionalEncoding(
        num_hiddens=768, max_len=max_seq_len
    )
    model.layoutlmv2.embeddings.max_position_embeddings = max_seq_len
    model.config.max_position_embeddings = max_seq_len
    model.layoutlmv2.embeddings.register_buffer(
        "position_ids", torch.arange(max_seq_len).expand((1, -1))
    )
    return model, processor


def load_layoutlmv2(
    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
):
    tokenizer = LayoutXLMTokenizer.from_pretrained(
        pretrained_model_name_or_path=tokenizer_dir, model_max_length=max_seq_len
    )
    # apply_ocr=False: words and bounding boxes are supplied by the caller
    # instead of being extracted with the built-in OCR.
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    processor = LayoutXLMProcessor(feature_extractor, tokenizer)
    model = LayoutLMv2ForTokenClassification.from_pretrained(
        weight_dir, num_labels=len(classes)
    )
    return model, processor
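

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the directory
    # paths, class list, and sequence length below are hypothetical
    # placeholders; point them at a fine-tuned LayoutXLM checkpoint and its
    # tokenizer before running.
    classes = ["O", "B-KEY", "I-KEY", "B-VALUE", "I-VALUE"]  # assumed label set
    model, processor = load_layoutlmv2_custom_model(
        weight_dir="weights/layoutxlm-invoice",       # assumed checkpoint dir
        tokenizer_dir="weights/layoutxlm-tokenizer",  # assumed tokenizer dir
        max_seq_len=2048,                             # longer than the default 512
        classes=classes,
    )
    model.eval()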