395 lines
14 KiB
Python
Executable File
395 lines
14 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
|
|
class DarknetBottleneck(nn.Module):
    """The basic bottleneck block used in Darknet.

    The block chains two conv modules: a 1x1 convolution that maps
    ``in_channels`` to the hidden width, followed by a 3x3 convolution
    (stride 1, padding 1) that maps the hidden width to ``out_channels``.
    The input is added to the result when the residual connection is
    enabled. Convolution, normalization and activation of both conv modules
    are configured through ``conv_cfg``, ``norm_cfg`` and ``act_cfg``.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio of the hidden channels to ``out_channels``;
            the hidden width is ``int(out_channels * expansion)``.
            Default: 0.5
        add_identity (bool): Whether to add identity to the out. The
            shortcut is only used when ``in_channels == out_channels``.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the 3x3 convolution. Default: False
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=0.5,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish')):
        super().__init__()
        hidden_channels = int(out_channels * expansion)

        if use_depthwise:
            # Imported lazily so the default (dense) path needs nothing
            # beyond the ConvModule import the file already has.
            from mmcv.cnn import DepthwiseSeparableConvModule
            conv = DepthwiseSeparableConvModule
        else:
            conv = ConvModule

        # The 1x1 projection is always a dense convolution.
        self.conv1 = ConvModule(
            in_channels,
            hidden_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = conv(
            hidden_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # A residual sum is only shape-compatible when the channel count is
        # preserved, so the flag is masked by that condition.
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x):
        """Apply both convolutions and, if enabled, add the input back."""
        out = self.conv2(self.conv1(x))
        if self.add_identity:
            return out + x
        return out
|
|
|
|
|
|
class CSPLayer(nn.Module):
    """Cross Stage Partial Layer.

    The input is projected by two parallel 1x1 conv modules. One branch is
    additionally passed through a stack of ``DarknetBottleneck`` blocks; the
    two branches are then concatenated along dim 1 and fused by a final 1x1
    conv module.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Default: 0.5
        num_blocks (int): Number of blocks. Default: 1
        add_identity (bool): Whether to add identity in blocks.
            Default: True
        use_depthwise (bool): Forwarded to every bottleneck block as its
            ``use_depthwise`` argument. Default: False
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio=0.5,
                 num_blocks=1,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish')):
        super().__init__()
        mid_channels = int(out_channels * expand_ratio)
        # Every sub-module shares the same conv / norm / activation configs.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.main_conv = ConvModule(
            in_channels, mid_channels, 1, **shared_cfg)
        self.short_conv = ConvModule(
            in_channels, mid_channels, 1, **shared_cfg)
        # The fusion conv sees both branches stacked on the channel axis.
        self.final_conv = ConvModule(
            2 * mid_channels, out_channels, 1, **shared_cfg)

        bottlenecks = []
        for _ in range(num_blocks):
            bottlenecks.append(
                DarknetBottleneck(
                    mid_channels,
                    mid_channels,
                    expansion=1.0,
                    add_identity=add_identity,
                    use_depthwise=use_depthwise,
                    **shared_cfg))
        self.blocks = nn.Sequential(*bottlenecks)

    def forward(self, x):
        """Run both branches and fuse their concatenation."""
        shortcut = self.short_conv(x)
        trunk = self.blocks(self.main_conv(x))
        # The bottleneck branch comes first in the concatenation.
        return self.final_conv(torch.cat((trunk, shortcut), dim=1))
|
|
|
|
|
|
|
|
class Focus(nn.Module):
    """Focus width and height information into channel space.

    The input is sampled with a step of 2 along its last two dimensions at
    four different offsets, the four samples are concatenated along dim 1,
    and the result is passed through a single conv module.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        kernel_size (int): The kernel size of the convolution. Default: 1
        stride (int): The stride of the convolution. Default: 1
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish')):
        super().__init__()
        pad = (kernel_size - 1) // 2
        # Four sub-sampled copies are stacked, hence 4x the input channels.
        self.conv = ConvModule(
            in_channels * 4,
            out_channels,
            kernel_size,
            stride,
            padding=pad,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Rearrange 2x2 neighbourhoods into channels, then convolve.

        For a 4-D input with even sizes in the last two dimensions, the
        tensor fed to the convolution has 4x the channels and half the size
        in each of the last two dimensions.
        """
        # (row, col) start offsets, in the order: top-left, bottom-left,
        # top-right, bottom-right.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        patches = [x[..., row::2, col::2] for row, col in offsets]
        return self.conv(torch.cat(patches, dim=1))
|
|
|
|
|
|
class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP.

    A 1x1 conv module first halves the channel count. Its output is then
    max-pooled once per entry in ``kernel_sizes`` (stride 1, padding
    ``ks // 2``), the un-pooled and pooled tensors are concatenated along
    dim 1, and a second 1x1 conv module maps the result to ``out_channels``.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        kernel_sizes (tuple[int]): Sequential of kernel sizes of pooling
            layers. Default: (5, 9, 13).
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_sizes=(5, 9, 13),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish')):
        super().__init__()
        mid_channels = in_channels // 2
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.conv1 = ConvModule(
            in_channels, mid_channels, 1, stride=1, **shared_cfg)

        pools = []
        for ks in kernel_sizes:
            pools.append(
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2))
        self.poolings = nn.ModuleList(pools)

        # One un-pooled copy plus one pooled copy per kernel size.
        num_branches = len(kernel_sizes) + 1
        self.conv2 = ConvModule(
            mid_channels * num_branches, out_channels, 1, **shared_cfg)

    def forward(self, x):
        """Reduce channels, pool at every kernel size, concatenate, fuse."""
        reduced = self.conv1(x)
        branches = [reduced]
        branches.extend(pool(reduced) for pool in self.poolings)
        return self.conv2(torch.cat(branches, dim=1))
|
|
|
|
|
|
class CSPDarknet(nn.Module):
    """CSP-Darknet backbone used in YOLOv5 and YOLOX.

    The network is a ``Focus`` stem followed by one stage per entry of the
    architecture setting. Each stage is a stride-2 3x3 convolution, an
    optional ``SPPBottleneck`` and a ``CSPLayer``. Layers are indexed as
    ``0`` for the stem and ``i`` for ``stage{i}``; ``out_indices`` and
    ``frozen_stages`` both use this indexing.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5, P6}.
            Default: P5.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Default: 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the stage downsampling convolutions; the flag is also
            forwarded to every ``CSPLayer``. Default: False.
        arch_ovewrite (list): Overwrite default arch settings. Each entry is
            ``[in_channels, out_channels, num_blocks, add_identity,
            use_spp]``. Default: None.
        spp_kernal_sizes (tuple[int]): Sequential of kernel sizes of SPP
            layers. Default: (5, 9, 13).
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.

    Raises:
        KeyError: If ``arch`` is not a key of ``arch_settings``.
        AssertionError: If ``out_indices`` contains an index outside
            ``range(len(arch_setting) + 1)``.
        ValueError: If ``frozen_stages`` is outside
            ``range(-1, len(arch_setting) + 1)``.
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
               [256, 512, 9, True, False], [512, 1024, 3, False, True]],
        'P6': [[64, 128, 3, True, False], [128, 256, 9, True, False],
               [256, 512, 9, True, False], [512, 768, 3, True, False],
               [768, 1024, 3, False, True]]
    }

    def __init__(self,
                 arch='P5',
                 deepen_factor=1.0,
                 widen_factor=1.0,
                 out_indices=(2, 3, 4),
                 frozen_stages=-1,
                 use_depthwise=False,
                 arch_ovewrite=None,
                 spp_kernal_sizes=(5, 9, 13),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 norm_eval=False):
        super().__init__()
        arch_setting = self.arch_settings[arch]
        if arch_ovewrite:
            arch_setting = arch_ovewrite
        # The stem counts as layer 0, followed by one layer per stage.
        num_layers = len(arch_setting) + 1
        assert set(out_indices).issubset(range(num_layers)), (
            f'out_indices must be a subset of range({num_layers}). '
            f'But received {out_indices}')
        if frozen_stages not in range(-1, num_layers):
            raise ValueError('frozen_stages must be in range(-1, '
                             'len(arch_setting) + 1). But received '
                             f'{frozen_stages}')

        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.use_depthwise = use_depthwise
        self.norm_eval = norm_eval

        if use_depthwise:
            # Imported lazily so the default (dense) path needs nothing
            # beyond the ConvModule import the file already has.
            from mmcv.cnn import DepthwiseSeparableConvModule
            conv = DepthwiseSeparableConvModule
        else:
            conv = ConvModule

        # The stem is fed 3-channel input.
        self.stem = Focus(
            3,
            int(arch_setting[0][0] * widen_factor),
            kernel_size=3,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Attribute names of the layers, in execution order.
        self.layers = ['stem']

        for i, (in_channels, out_channels, num_blocks, add_identity,
                use_spp) in enumerate(arch_setting):
            in_channels = int(in_channels * widen_factor)
            out_channels = int(out_channels * widen_factor)
            # Every stage keeps at least one bottleneck block.
            num_blocks = max(round(num_blocks * deepen_factor), 1)
            stage = []
            # Stride-2 convolution that halves the spatial resolution.
            conv_layer = conv(
                in_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            stage.append(conv_layer)
            if use_spp:
                spp = SPPBottleneck(
                    out_channels,
                    out_channels,
                    kernel_sizes=spp_kernal_sizes,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                stage.append(spp)
            csp_layer = CSPLayer(
                out_channels,
                out_channels,
                num_blocks=num_blocks,
                add_identity=add_identity,
                use_depthwise=use_depthwise,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            stage.append(csp_layer)
            stage_name = f'stage{i + 1}'
            self.add_module(stage_name, nn.Sequential(*stage))
            self.layers.append(stage_name)

    def _freeze_stages(self):
        """Put layers ``0..frozen_stages`` in eval mode and stop their grads."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode while keeping frozen parts frozen.

        Frozen stages are re-frozen after every mode switch and, when
        ``norm_eval`` is set, batch-norm layers stay in eval mode during
        training.

        Returns:
            CSPDarknet: ``self``, as ``nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def forward(self, x):
        """Run all layers in order and collect the selected outputs.

        Returns:
            tuple: Outputs of the layers whose index is in
            ``out_indices``, in execution order.
        """
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)