import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from .conv import ConvModule


class LocalityAwareFeedforward(nn.Module):
    """Locality-aware feedforward layer in SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    Args:
        d_in (int): Number of input (and output) channels.
        d_hid (int): Number of hidden channels.
        dropout (float): Dropout rate. Currently unused; kept for interface
            consistency with other feedforward layers.
    """

    def __init__(self,
                 d_in,
                 d_hid,
                 dropout=0.1):
        super().__init__()
        # Pointwise conv: expand d_in -> d_hid.
        self.conv1 = ConvModule(
            d_in,
            d_hid,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        # Depthwise 3x3 conv: mixes local spatial context per channel.
        self.depthwise_conv = ConvModule(
            d_hid,
            d_hid,
            kernel_size=3,
            padding=1,
            bias=False,
            groups=d_hid,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        # Pointwise conv: project d_hid -> d_in.
        self.conv2 = ConvModule(
            d_hid,
            d_in,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)

        return x


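# Not part of the original file: a minimal usage sketch for
# LocalityAwareFeedforward, assuming the local ConvModule behaves like
# mmcv's ConvModule (conv + BN + ReLU). The pointwise -> depthwise ->
# pointwise stack keeps the (N, C, H, W) shape of the feature map.
def _demo_locality_aware_feedforward():
    ffn = LocalityAwareFeedforward(d_in=512, d_hid=1024)
    x = torch.rand(2, 512, 8, 25)  # (N, C, H, W) feature map
    return ffn(x)  # same shape as the input: (2, 512, 8, 25)

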
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention Module. This code is adopted from
    https://github.com/jadore801120/attention-is-all-you-need-pytorch.

    Args:
        temperature (float): The scale factor for softmax input.
        attn_dropout (float): Dropout layer on attn_output_weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


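# Not part of the original file: a minimal sketch of ScaledDotProductAttention
# on per-head tensors of shape (N, n_head, T, d_k), using the usual
# temperature of sqrt(d_k).
def _demo_scaled_dot_product_attention():
    d_k = 64
    attention = ScaledDotProductAttention(temperature=d_k**0.5)
    q = torch.rand(2, 8, 10, d_k)  # (N, n_head, T_q, d_k)
    k = torch.rand(2, 8, 12, d_k)  # (N, n_head, T_k, d_k)
    v = torch.rand(2, 8, 12, d_k)  # (N, n_head, T_k, d_v)
    out, attn = attention(q, k, v)
    return out.shape, attn.shape  # (2, 8, 10, 64), (2, 8, 10, 12)

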
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention module.

    Args:
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_k (int): Dimension of the key vector per head.
        d_v (int): Dimension of the value vector per head.
        dropout (float): Dropout layer on attn_output_weights.
        qkv_bias (bool): Add bias in projection layer. Default: False.
    """

    def __init__(self,
                 n_head=8,
                 d_model=512,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.dim_k = n_head * d_k
        self.dim_v = n_head * d_v

        self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
        self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
        self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)

        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)

        self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size, len_q, _ = q.size()
        _, len_k, _ = k.size()

        # Project and split into heads: (N, T, n_head * d) -> (N, T, n_head, d).
        q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
        k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
        v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)

        # Move the head dimension forward: (N, n_head, T, d).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            # Broadcast the mask over heads (and over queries, if it is 2D):
            # (N, T_q, T_k) -> (N, 1, T_q, T_k) or (N, T_k) -> (N, 1, 1, T_k).
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            elif mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)

        attn_out, _ = self.attention(q, k, v, mask=mask)

        # Merge heads back: (N, n_head, T_q, d_v) -> (N, T_q, n_head * d_v).
        attn_out = attn_out.transpose(1, 2).contiguous().view(
            batch_size, len_q, self.dim_v)

        attn_out = self.fc(attn_out)
        attn_out = self.proj_drop(attn_out)

        return attn_out


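# Not part of the original file: a minimal self-attention sketch for
# MultiHeadAttention. With the defaults, d_model == n_head * d_k == 512, so a
# (N, T, 512) sequence maps to a (N, T, 512) output. A 2D mask uses 1 for
# valid positions and 0 for padding.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
    x = torch.rand(2, 10, 512)      # (N, T, d_model)
    mask = torch.ones(2, 10)        # broadcast internally to (N, 1, 1, T)
    return mha(x, x, x, mask=mask)  # (2, 10, 512)

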
class SatrnEncoderLayer(nn.Module):
    """A single SATRN encoder layer: pre-norm multi-head self-attention
    followed by a locality-aware feedforward block, each wrapped in a
    residual connection.

    Args:
        d_model (int): Dimension of the input features.
        d_inner (int): Hidden dimension of the feedforward block.
        n_head (int): Number of attention heads.
        d_k (int): Dimension of the key vector per head.
        d_v (int): Dimension of the value vector per head.
        dropout (float): Dropout rate.
        qkv_bias (bool): Add bias in the q/k/v projection layers.
    """

    def __init__(self,
                 d_model=512,
                 d_inner=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = LocalityAwareFeedforward(
            d_model, d_inner, dropout=dropout)

    def forward(self, x, h, w, mask=None):
        n, hw, c = x.size()

        # Self-attention with a residual connection (pre-norm).
        residual = x
        x = self.norm1(x)
        x = residual + self.attn(x, x, x, mask)

        # Locality-aware feedforward: reshape the sequence back into a 2D
        # feature map so the convolutions can see spatial neighbours.
        residual = x
        x = self.norm2(x)
        x = x.transpose(1, 2).contiguous().view(n, c, h, w)
        x = self.feed_forward(x)
        x = x.view(n, c, hw).transpose(1, 2)
        x = residual + x
        return x


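# Not part of the original file: a minimal sketch of one SatrnEncoderLayer.
# The layer consumes a flattened feature sequence of shape (N, H * W, C) and
# also needs the original h and w so that the feedforward block can reshape
# it back into a 2D feature map.
def _demo_satrn_encoder_layer():
    h, w = 8, 25
    layer = SatrnEncoderLayer(d_model=512, d_inner=512, n_head=8)
    x = torch.rand(2, h * w, 512)  # (N, H * W, C)
    return layer(x, h, w)          # (2, 200, 512)

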
class Adaptive2DPositionalEncoding(nn.Module):
    """Implement Adaptive 2D positional encoder for SATRN, see
    `SATRN <https://arxiv.org/abs/1910.04396>`_.

    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");

    Args:
        d_hid (int): Dimensions of hidden layer.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (float): Dropout rate.
    """

    def __init__(self,
                 d_hid=512,
                 n_height=100,
                 n_width=100,
                 dropout=0.1):
        super().__init__()

        # Fixed sinusoidal tables for the height and width axes, stored as
        # buffers of shape (1, d_hid, n_height, 1) and (1, d_hid, 1, n_width).
        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        # Learned, input-dependent scale factors for each axis.
        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())

        return scale_factor

    def forward(self, x):
        b, c, h, w = x.size()

        avg_pool = self.pool(x)

        # Scale the fixed sinusoid tables by input-dependent factors, then
        # crop them to the actual feature size.
        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        # The positional encoding is added to the input here, so the caller
        # should not add ``x`` again.
        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out


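# Not part of the original file: a minimal sketch of the adaptive 2D
# positional encoding. The module returns its input plus the scaled height
# and width encodings, so the shape is unchanged and the caller must not add
# the input a second time.
def _demo_adaptive_2d_positional_encoding():
    pos_enc = Adaptive2DPositionalEncoding(d_hid=512, n_height=100, n_width=100)
    feat = torch.rand(2, 512, 8, 25)  # (N, C, H, W) with H, W <= 100
    return pos_enc(feat)              # (2, 512, 8, 25), input already included

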
class SatrnEncoder(nn.Module):
    """Implement encoder for SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    Args:
        n_layers (int): Number of attention layers.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        d_inner (int): Hidden dimension of feedforward layers.
        dropout (float): Dropout rate.
        init_cfg (dict or list[dict], optional): Initialization configs
            (accepted via ``**kwargs`` but unused here).
    """

    def __init__(self,
                 n_layers=12,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 n_position=100,
                 d_inner=256,
                 dropout=0.1,
                 **kwargs):
        super().__init__()
        self.d_model = d_model
        self.position_enc = Adaptive2DPositionalEncoding(
            d_hid=d_model,
            n_height=n_position,
            n_width=n_position,
            dropout=dropout)
        self.layer_stack = nn.ModuleList([
            SatrnEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, T, D_m)`, where ``T = H * W``.
        """
        # Adaptive2DPositionalEncoding already adds the positional encoding
        # to its input, so do not add ``feat`` to its output again.
        feat = self.position_enc(feat)
        n, c, h, w = feat.size()

        # All spatial positions are treated as valid, so the attention mask
        # is simply all ones of shape (N, H * W).
        mask = feat.new_zeros((n, h, w))
        mask[:, :, :w] = 1
        mask = mask.view(n, h * w)

        # Flatten the 2D feature map into a sequence: (N, H * W, D_m).
        feat = feat.view(n, c, h * w)
        output = feat.permute(0, 2, 1).contiguous()
        for enc_layer in self.layer_stack:
            output = enc_layer(output, h, w, mask)
        output = self.layer_norm(output)

        return output
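

# Not part of the original file: a minimal end-to-end sketch of the SATRN
# encoder, assuming a backbone feature map of shape (N, D_m, H, W). The
# encoder flattens it into a sequence of length H * W.
def _demo_satrn_encoder():
    encoder = SatrnEncoder(n_layers=2, n_head=8, d_model=512, n_position=100)
    feat = torch.rand(2, 512, 8, 25)  # (N, D_m, H, W)
    return encoder(feat)              # (2, 200, 512), i.e. (N, H * W, D_m)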