225 lines
9.1 KiB
Python
225 lines
9.1 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
from torch.nn.modules.utils import _pair
|
||
|
|
||
|
|
||
|
class MlvlPointGenerator:
    """Standard points generator for multi-level (Mlvl) feature maps in 2D
    points-based detectors.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels in order (w, h). A scalar stride is
            expanded to a ``(stride, stride)`` pair.
        offset (float): The offset of points, the value is normalized with
            corresponding stride. Defaults to 0.5.
    """

    def __init__(self, strides, offset=0.5):
        # Normalise every stride to a 2-tuple so later code can always
        # unpack it as (stride_w, stride_h).
        self.strides = [_pair(stride) for stride in strides]
        self.offset = offset

    @property
    def num_levels(self):
        """int: number of feature levels that the generator will be applied"""
        return len(self.strides)

    @property
    def num_base_priors(self):
        """list[int]: The number of priors (points) at a point
        on the feature grid"""
        return [1] * len(self.strides)

    def _meshgrid(self, x, y, row_major=True):
        """Build the flattened 2D grid spanned by two 1D tensors.

        Args:
            x (Tensor): 1D tensor of values along the x (width) axis.
            y (Tensor): 1D tensor of values along the y (height) axis.
            row_major (bool): If ``True`` return ``(xx, yy)``, otherwise
                return ``(yy, xx)``. Defaults to True.

        Returns:
            tuple[Tensor, Tensor]: Two 1D tensors of length
            ``len(x) * len(y)``; x varies fastest in both of them.
        """
        yy, xx = torch.meshgrid(y, x)
        # warning .flatten() would cause error in ONNX exporting
        # have to use reshape here
        if row_major:
            return xx.reshape(-1), yy.reshape(-1)
        return yy.reshape(-1), xx.reshape(-1)

    def grid_priors(self,
                    featmap_sizes,
                    dtype=torch.float32,
                    device='cuda',
                    with_stride=False):
        """Generate grid points of multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
            device (str): The device where the anchors will be put on.
                Defaults to 'cuda'.
            with_stride (bool): Whether to concatenate the stride to
                the last dimension of points.

        Return:
            list[torch.Tensor]: Points of multiple feature levels.
            The sizes of each tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        assert self.num_levels == len(featmap_sizes)
        return [
            self.single_level_grid_priors(
                featmap_size,
                level_idx=level_idx,
                dtype=dtype,
                device=device,
                with_stride=with_stride)
            for level_idx, featmap_size in enumerate(featmap_sizes)
        ]

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=torch.float32,
                                 device='cuda',
                                 with_stride=False):
        """Generate grid Points of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature maps, arranged as
                (h, w).
            level_idx (int): The index of corresponding feature map level.
            dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
            device (str, optional): The device the tensor will be put on.
                Defaults to 'cuda'.
            with_stride (bool): Concatenate the stride to the last dimension
                of points.

        Return:
            Tensor: Points of single feature levels.
            The shape of tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]

        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_x = (torch.arange(0, feat_w, device=device) +
                   self.offset) * stride_w
        shift_x = shift_x.to(dtype)

        shift_y = (torch.arange(0, feat_h, device=device) +
                   self.offset) * stride_h
        shift_y = shift_y.to(dtype)

        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        if not with_stride:
            shifts = torch.stack([shift_xx, shift_yy], dim=-1)
        else:
            # use `shape[0]` instead of `len(shift_xx)` for ONNX export
            stride_w = shift_xx.new_full((shift_xx.shape[0], ),
                                         stride_w).to(dtype)
            stride_h = shift_xx.new_full((shift_yy.shape[0], ),
                                         stride_h).to(dtype)
            shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
                                 dim=-1)
        all_points = shifts.to(device)
        return all_points

    def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
        """Generate valid flags of points of multiple feature levels.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            pad_shape (tuple(int)): The padded shape of the image,
                arranged as (h, w). Only the first two entries are used.
            device (str): The device where the anchors will be put on.
                Defaults to 'cuda'.

        Return:
            list(torch.Tensor): Valid flags of points of multiple levels.
        """
        assert self.num_levels == len(featmap_sizes)
        # The padded image shape is the same for every level.
        pad_h, pad_w = pad_shape[:2]
        multi_level_flags = []
        for (stride_w, stride_h), (feat_h, feat_w) in zip(
                self.strides, featmap_sizes):
            # Number of grid cells needed to cover the padded image,
            # capped by the actual feature map size.
            valid_feat_h = min(int(np.ceil(pad_h / stride_h)), feat_h)
            valid_feat_w = min(int(np.ceil(pad_w / stride_w)), feat_w)
            flags = self.single_level_valid_flags(
                (feat_h, feat_w), (valid_feat_h, valid_feat_w),
                device=device)
            multi_level_flags.append(flags)
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size,
                                 device='cuda'):
        """Generate the valid flags of points of a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps, arranged
                as (h, w).
            valid_size (tuple[int]): The valid size of the feature maps,
                arranged as (h, w). Must not exceed ``featmap_size``.
            device (str, optional): The device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: The valid flags of each points in a single level
            feature map, a 1D bool tensor of length ``h * w``.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = True
        valid_y[:valid_h] = True
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        # A point is valid only when both its column and its row are valid.
        return valid_xx & valid_yy

    def sparse_priors(self,
                      prior_idxs,
                      featmap_size,
                      level_idx,
                      dtype=torch.float32,
                      device='cuda'):
        """Generate sparse points according to the ``prior_idxs``.

        Args:
            prior_idxs (Tensor): The index of corresponding anchors
                in the feature map.
            featmap_size (tuple[int]): feature map size arranged as (h, w).
            level_idx (int): The level index of corresponding feature
                map.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (obj:`torch.device`): The device where the points is
                located. Defaults to 'cuda'.

        Returns:
            Tensor: Anchor with shape (N, 2), N should be equal to
            the length of ``prior_idxs``. And last dimension
            2 represent (coord_x, coord_y).
        """
        height, width = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        # Flattened indices run along the width first: column is
        # ``idx % width`` and row is ``idx // width``.
        x = (prior_idxs % width + self.offset) * stride_w
        y = ((prior_idxs // width) % height + self.offset) * stride_h
        priors = torch.stack([x, y], 1).to(dtype)
        priors = priors.to(device)
        return priors
|