import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .encoder import MultiHeadAttention


class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward module.

    The activation between the two linear layers is fixed to ``nn.GELU``.

    Args:
        d_in (int): Dimension of the input (and output) features.
        d_hid (int): Dimension of the hidden feed-forward layer.
        dropout (float): Dropout rate applied to the feed-forward output.
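
    Example:
        A minimal shape check; the sizes below are illustrative only:

        >>> ffn = PositionwiseFeedForward(d_in=512, d_hid=256)
        >>> x = torch.rand(2, 25, 512)
        >>> ffn(x).shape
        torch.Size([2, 25, 512])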
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w_1(x)
        x = self.act(x)
        x = self.w_2(x)
        x = self.dropout(x)

        return x


class PositionalEncoding(nn.Module):
    """Fixed positional encoding with sine and cosine functions."""

    def __init__(self, d_hid=512, n_position=200):
        super().__init__()

        # Not a parameter.
        # Position table of shape (1, n_position, d_hid).
        self.register_buffer(
            'position_table',
            self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table.unsqueeze(0)

    def forward(self, x):
        """
        Args:
            x (Tensor): Tensor of shape (batch_size, pos_len, d_hid).
        """
        self.device = x.device
        x = x + self.position_table[:, :x.size(1)].clone().detach()
        return x


class TFDecoderLayer(nn.Module):
    """Transformer Decoder Layer.

    The sub-layers run in a fixed pre-norm order with residual connections:
    ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn').

    Args:
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Total number of features in key.
        d_v (int): Total number of features in value.
        dropout (float): Dropout rate on the attention output weights and in
            the feedforward module.
        qkv_bias (bool): Add bias in projection layer. Default: False.
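
    Example:
        A shape-only sketch; it assumes the ``MultiHeadAttention`` imported
        from ``.encoder`` accepts ``(q, k, v, mask)`` and preserves the query
        shape, as the calls in this module suggest:

        >>> layer = TFDecoderLayer(d_model=512, d_inner=256)  # doctest: +SKIP
        >>> dec_in = torch.rand(2, 25, 512)  # doctest: +SKIP
        >>> enc_out = torch.rand(2, 100, 512)  # doctest: +SKIP
        >>> layer(dec_in, enc_out).shape  # doctest: +SKIP
        torch.Size([2, 25, 512])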
    """

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout)

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        # Pre-norm masked self-attention with a residual connection.
        dec_input_norm = self.norm1(dec_input)
        dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                      dec_input_norm, self_attn_mask)
        dec_attn_out += dec_input

        # Pre-norm encoder-decoder (cross) attention with a residual connection.
        enc_dec_attn_in = self.norm2(dec_attn_out)
        enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                         enc_output, dec_enc_attn_mask)
        enc_dec_attn_out += dec_attn_out

        # Pre-norm feed-forward block with a residual connection.
        mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
        mlp_out += enc_dec_attn_out

        return mlp_out


class NRTRDecoder(nn.Module):
    """Transformer decoder block with a self-attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
        return_confident (bool): If True, each decoding step returns the full
            softmax distribution instead of the argmax class index.
        end_idx (int, optional): The index of `<EOS>`. If given, decoding
            stops early once every sequence in the batch predicts an index
            no smaller than ``end_idx``.

    Warning:
        This decoder will not predict the final class, which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by the loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
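
    Example:
        A rough smoke-test sketch; it assumes the defaults above and the same
        ``MultiHeadAttention`` behaviour as described in
        :class:`TFDecoderLayer`:

        >>> decoder = NRTRDecoder()  # doctest: +SKIP
        >>> out_enc = torch.rand(2, 100, 512)  # doctest: +SKIP
        >>> decoder(out_enc).shape  # doctest: +SKIP
        torch.Size([2, 40])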
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 return_confident=False,
                 end_idx=None,
                 **kwargs):
        super().__init__()

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)

        self.layer_stack = nn.ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

        self.return_confident = return_confident
        self.end_idx = end_idx

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        """Mask out `<PAD>` positions; shape (N, 1, T)."""
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.size(1)
        subsequent_mask = 1 - torch.triu(
            torch.ones((len_s, len_s), device=seq.device), diagonal=1)
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()
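        # Shape (1, len_s, len_s); lower-triangular so position t attends only
        # to positions <= t. For len_s == 3 (True = visible):
        #   [[ True, False, False],
        #    [ True,  True, False],
        #    [ True,  True,  True]]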

        return subsequent_mask

    def _attention(self, trg_seq, src, src_mask=None):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        # Combine the padding mask with the causal (subsequent) mask.
        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = trg_pos_encoded
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit):
        # Every encoder position is treated as valid, so the mask is all ones.
        N, T, _ = logit.size()
        mask = logit.new_ones((N, T))

        return mask

    def forward(self, out_enc):
        """Greedily decode the encoder output ``out_enc`` of shape (N, T, D_m)."""
        src_mask = self._get_mask(out_enc)
        N = out_enc.size(0)
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        # Shape (N, max_seq_len + 1); position 0 holds <SOS>.
        init_target_seq[:, 0] = self.start_idx
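
        # Greedy autoregressive decoding: at step t the whole (padded) target
        # sequence is re-run through the decoder stack, only the prediction at
        # position t is kept, and its argmax becomes the input token at t + 1.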
        outputs = []
        for step in range(0, self.max_seq_len):
            decoder_output = self._attention(
                trg_seq=init_target_seq,
                src=out_enc,
                src_mask=src_mask)
            step_logit = self.classifier(decoder_output[:, step, :])
            if self.return_confident:
                # Keep the full softmax distribution for this step.
                step_result = torch.softmax(step_logit, -1)
                next_step_init = step_result.argmax(-1)
            else:
                step_result = step_logit.argmax(-1)
                next_step_init = step_result
            init_target_seq[:, step + 1] = next_step_init
            # Stop early once every sequence has predicted an index of at
            # least end_idx; skip the check when end_idx is None.
            if self.end_idx is not None and next_step_init.min() >= self.end_idx:
                break

            outputs.append(step_result)

        outputs = torch.stack(outputs, dim=1)

        return outputs