import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .conv import ConvModule


class LocalityAwareFeedforward(nn.Module):
    """Locality-aware feedforward layer in SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    Args:
        d_in (int): Dimension of the input and output features.
        d_hid (int): Hidden dimension of the convolutional layers.
        dropout (float): Dropout rate. Currently unused by this layer.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # 1x1 pointwise convolution expanding d_in -> d_hid.
        self.conv1 = ConvModule(
            d_in,
            d_hid,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
        # 3x3 depthwise convolution capturing local 2D context.
        self.depthwise_conv = ConvModule(
            d_hid,
            d_hid,
            kernel_size=3,
            padding=1,
            bias=False,
            groups=d_hid,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
        # 1x1 pointwise convolution projecting back d_hid -> d_in.
        self.conv2 = ConvModule(
            d_hid,
            d_in,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)

        return x


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention module.

    This code is adapted from
    https://github.com/jadore801120/attention-is-all-you-need-pytorch.

    Args:
        temperature (float): The scale factor for softmax input.
        attn_dropout (float): Dropout layer on attn_output_weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention module.

    Args:
        n_head (int): The number of heads in the multi-head attention
            module (default=8).
        d_model (int): The number of expected features in the decoder
            inputs (default=512).
        d_k (int): Dimension of the key vector per head.
        d_v (int): Dimension of the value vector per head.
        dropout (float): Dropout layer on attn_output_weights.
        qkv_bias (bool): Add bias in projection layer. Default: False.
    """

    def __init__(self,
                 n_head=8,
                 d_model=512,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.dim_k = n_head * d_k
        self.dim_v = n_head * d_v

        self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
        self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
        self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)

        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)

        self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size, len_q, _ = q.size()
        _, len_k, _ = k.size()

        # Project and split into heads: (N, T, n_head, d_k/d_v).
        q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
        k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
        v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            elif mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)

        attn_out, _ = self.attention(q, k, v, mask=mask)

        # Merge heads back: (N, T, n_head * d_v).
        attn_out = attn_out.transpose(1, 2).contiguous().view(
            batch_size, len_q, self.dim_v)

        attn_out = self.fc(attn_out)
        attn_out = self.proj_drop(attn_out)

        return attn_out
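
# Illustrative shape check (not part of the original module). A minimal
# sketch showing MultiHeadAttention used for self-attention: with the default
# 8 heads and d_k = d_v = 64, it maps (N, T, 512) inputs back to (N, T, 512).
# The function name `_demo_multi_head_attention` and the tensor sizes are
# arbitrary examples, not values taken from the original code.
def _demo_multi_head_attention():
    """Run a tiny self-attention forward pass and check output shapes."""
    x = torch.rand(2, 25, 512)  # (batch, sequence length, d_model)
    mask = torch.ones(2, 25)  # a 2D mask is broadcast over heads and queries
    attn = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
    out = attn(x, x, x, mask=mask)
    assert out.shape == (2, 25, 512)
    return out
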
""" def __init__(self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1, qkv_bias=False): super().__init__() self.n_head = n_head self.d_k = d_k self.d_v = d_v self.dim_k = n_head * d_k self.dim_v = n_head * d_v self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias) self.attention = ScaledDotProductAttention(d_k**0.5, dropout) self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias) self.proj_drop = nn.Dropout(dropout) def forward(self, q, k, v, mask=None): batch_size, len_q, _ = q.size() _, len_k, _ = k.size() q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k) k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k) v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) if mask is not None: if mask.dim() == 3: mask = mask.unsqueeze(1) elif mask.dim() == 2: mask = mask.unsqueeze(1).unsqueeze(1) attn_out, _ = self.attention(q, k, v, mask=mask) attn_out = attn_out.transpose(1, 2).contiguous().view( batch_size, len_q, self.dim_v) attn_out = self.fc(attn_out) attn_out = self.proj_drop(attn_out) return attn_out class SatrnEncoderLayer(nn.Module): """""" def __init__(self, d_model=512, d_inner=512, n_head=8, d_k=64, d_v=64, dropout=0.1, qkv_bias=False): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.attn = MultiHeadAttention( n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) self.norm2 = nn.LayerNorm(d_model) self.feed_forward = LocalityAwareFeedforward( d_model, d_inner, dropout=dropout) def forward(self, x, h, w, mask=None): n, hw, c = x.size() residual = x x = self.norm1(x) x = residual + self.attn(x, x, x, mask) residual = x x = self.norm2(x) x = x.transpose(1, 2).contiguous().view(n, c, h, w) x = self.feed_forward(x) x = x.view(n, c, hw).transpose(1, 2) x = residual + x return x class Adaptive2DPositionalEncoding(nn.Module): """Implement Adaptive 2D positional encoder for SATRN, see `SATRN `_ Modified from https://github.com/Media-Smart/vedastr Licensed under the Apache License, Version 2.0 (the "License"); Args: d_hid (int): Dimensions of hidden layer. n_height (int): Max height of the 2D feature output. n_width (int): Max width of the 2D feature output. dropout (int): Size of hidden layers of the model. 
""" def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1): super().__init__() h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid) h_position_encoder = h_position_encoder.transpose(0, 1) h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1) w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid) w_position_encoder = w_position_encoder.transpose(0, 1) w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width) self.register_buffer('h_position_encoder', h_position_encoder) self.register_buffer('w_position_encoder', w_position_encoder) self.h_scale = self.scale_factor_generate(d_hid) self.w_scale = self.scale_factor_generate(d_hid) self.pool = nn.AdaptiveAvgPool2d(1) self.dropout = nn.Dropout(p=dropout) def _get_sinusoid_encoding_table(self, n_position, d_hid): """Sinusoid position encoding table.""" denominator = torch.Tensor([ 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid) ]) denominator = denominator.view(1, -1) pos_tensor = torch.arange(n_position).unsqueeze(-1).float() sinusoid_table = pos_tensor * denominator sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) return sinusoid_table def scale_factor_generate(self, d_hid): scale_factor = nn.Sequential( nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid()) return scale_factor def forward(self, x): b, c, h, w = x.size() avg_pool = self.pool(x) h_pos_encoding = \ self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] w_pos_encoding = \ self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] out = x + h_pos_encoding + w_pos_encoding out = self.dropout(out) return out class SatrnEncoder(nn.Module): """Implement encoder for SATRN, see `SATRN. `_. Args: n_layers (int): Number of attention layers. n_head (int): Number of parallel attention heads. d_k (int): Dimension of the key vector. d_v (int): Dimension of the value vector. d_model (int): Dimension :math:`D_m` of the input from previous model. n_position (int): Length of the positional encoding vector. Must be greater than ``max_seq_len``. d_inner (int): Hidden dimension of feedforward layers. dropout (float): Dropout rate. init_cfg (dict or list[dict], optional): Initialization configs. """ def __init__(self, n_layers=12, n_head=8, d_k=64, d_v=64, d_model=512, n_position=100, d_inner=256, dropout=0.1, **kwargs): super().__init__() self.d_model = d_model self.position_enc = Adaptive2DPositionalEncoding( d_hid=d_model, n_height=n_position, n_width=n_position, dropout=dropout) self.layer_stack = nn.ModuleList([ SatrnEncoderLayer( d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers) ]) self.layer_norm = nn.LayerNorm(d_model) def forward(self, feat): """ Args: feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`. Returns: Tensor: A tensor of shape :math:`(N, T, D_m)`. """ feat = feat + self.position_enc(feat) n, c, h, w = feat.size() mask = feat.new_zeros((n, h, w)) mask[:,:,:w] = 1 mask = mask.view(n, h * w) feat = feat.view(n, c, h * w) output = feat.permute(0, 2, 1).contiguous() for enc_layer in self.layer_stack: output = enc_layer(output, h, w, mask) output = self.layer_norm(output) return output