173 lines
7.2 KiB
Python
Executable File
173 lines
7.2 KiB
Python
Executable File
import warnings
|
|
import torch.nn as nn
|
|
from torch.nn.modules.instancenorm import _InstanceNorm
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
from mmcv.cnn import build_padding_layer, build_conv_layer, build_norm_layer, build_activation_layer
|
|
|
|
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
    supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Padding added to both sides of the
            input; how the border is filled is selected by `padding_mode`.
            Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation. Ignored
            for activation types that take no ``inplace`` argument.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm in conv
            module. Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = 'conv_block'

    # Activation types whose constructors do not accept an ``inplace``
    # keyword, so ``inplace`` must not be injected into their config.
    # ``nn.GELU`` has no ``inplace`` argument either; without it here a
    # GELU config would fail with an unexpected-keyword error.
    _acts_without_inplace = ('Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish',
                             'GELU')

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        # Modes PyTorch's conv does not handle natively get a separate
        # padding layer applied before the conv.
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            self.padding_layer = build_padding_layer(pad_cfg, padding)

        # reset padding to 0 for conv module when an explicit padding layer
        # already pads the input
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        # Report the user-requested padding, not self.conv.padding, which is
        # 0 when an explicit padding layer is used.
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer -> it sees out_channels;
            # otherwise it normalizes the block input.
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)
            # An explicitly requested bias is kept, but flagged when the
            # following batch/instance norm would subtract it out anyway.
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            # copy so the caller's (possibly shared default) dict is not
            # mutated by setdefault below
            act_cfg_ = act_cfg.copy()
            if act_cfg_['type'] not in self._acts_without_inplace:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

    @property
    def norm(self):
        """The normalization sub-module, or None if no norm layer was built."""
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def forward(self, x, activate=True, norm=True):
        """Apply the conv/norm/act layers to ``x`` in ``self.order``.

        Args:
            x (Tensor): Input feature map.
            activate (bool): If False, skip the activation layer even when
                one was built. Default: True.
            norm (bool): If False, skip the normalization layer even when
                one was built. Default: True.

        Returns:
            Tensor: Output of the block.
        """
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x