2023-12-12 18:51:32 +07:00

151 lines
4.9 KiB
Executable File

import torch
import torch.nn as nn
import numpy as np
from .backbone import CSPDarknet
from .neck import YOLOXPAFPN
from .bbox_head import YOLOXHead
from .transform import DetectorDataPipeline, AutoRotateDetectorDataPipeline
from .factory import _get as get_version
def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of per-class numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): Detections, shape (n, 5).
        labels (torch.Tensor | np.ndarray): Class index of each detection,
            shape (n, ).
        num_classes (int): Number of per-class arrays to produce.

    Returns:
        list[np.ndarray]: ``num_classes`` arrays; entry ``i`` holds the rows of
        ``bboxes`` whose label equals ``i``. When there are no detections every
        entry is an empty float32 array of shape (0, 5).
    """
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]
    # Convert each input on its own so a mix of tensor and ndarray is accepted.
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    return [bboxes[labels == i, :] for i in range(num_classes)]
def normalize_bbox(bboxes, scale):
    """Divide the first four columns of every array in ``bboxes`` by ``scale``.

    The arrays are modified in place.

    Args:
        bboxes (list[np.ndarray]): 2-D arrays (one per class); columns
            ``[:4]`` of each are rescaled.
        scale (float): Divisor applied to those columns.

    Returns:
        list[np.ndarray]: The same ``bboxes`` list object.
    """
    for class_bboxes in bboxes:
        class_bboxes[:, :4] /= scale
    return bboxes
class SingleStageDetector(nn.Module):
    """Inference-only text detector: CSPDarknet backbone + YOLOXPAFPN neck + YOLOXHead.

    Constructor arguments for every sub-module and the weights are read from
    the checkpoint that ``get_version(version)`` resolves to. All parameters
    are frozen and the module is switched to eval mode.

    Args:
        version: Checkpoint identifier passed to ``get_version``.
        device (str): Target device; must contain ``'cpu'`` or ``'cuda'``.
    """

    # NOTE(review): confirm this signature (and any default for ``version``)
    # against existing callers.
    def __init__(self, version, device='cpu'):
        super().__init__()
        assert 'cpu' in device or 'cuda' in device
        checkpoint = get_version(version)
        # Load on CPU first; the whole module is moved to ``device`` afterwards.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        pt = torch.load(checkpoint, 'cpu')
        self.pipeline = DetectorDataPipeline(**pt['pipeline_args'], device=device)
        self.backbone = CSPDarknet(**pt['backbone_args'])
        self.neck = YOLOXPAFPN(**pt['neck_args'])
        self.bbox_head = YOLOXHead(**pt['bbox_head_args'])
        self.load_state_dict(pt['state_dict'], strict=True)
        for param in self.parameters():
            param.requires_grad = False
        # nn.Module.to() moves parameters in place, so rebinding ``self`` is
        # unnecessary. eval() puts layers such as BatchNorm/Dropout (if any)
        # into inference behaviour, which requires_grad=False alone does not do.
        self.to(device)
        self.eval()
        print(f'Text detection load from version {version}.')

    def extract_feat(self, img):
        """Directly extract features from the backbone + neck."""
        return self.neck(self.backbone(img))

    @torch.no_grad()
    def forward(self, img):
        """Detect text boxes in one image (no test-time augmentation).

        Args:
            img (np.ndarray | str): Image of shape (H, W, C), or a path to an
                image; preprocessing is delegated to ``self.pipeline``.

        Returns:
            list[list[np.ndarray]]: For each result produced by the head, one
            array per class. The first four columns of each array are divided
            by the resize scale reported by the pipeline.
        """
        img, origin_shape, new_shape = self.pipeline(img)
        # Assumes the pipeline returns shapes as arrays and resizes with one
        # aspect-preserving factor; the smaller ratio recovers it. TODO confirm.
        scale = min(new_shape / origin_shape)
        feat = self.extract_feat(img)
        results_list = self.bbox_head.simple_test_bboxes(feat)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        # Each entry is a list of per-class arrays, so rescale per entry;
        # passing the outer list directly would index a list with [:, :4].
        return [normalize_bbox(class_results, scale) for class_results in bbox_results]
class AutoRotateDetector(nn.Module):
    """Return the candidate image on which the detector finds the most boxes.

    ``self.pipeline`` turns the input into a sequence of model inputs plus a
    matching sequence of numpy images (presumably different rotations of the
    input - TODO confirm). Every candidate is run through the YOLOX-style
    detector, and the numpy image with the largest detection count is returned.

    Args:
        version: Checkpoint identifier passed to ``get_version``.
        device (str): Target device; must contain ``'cpu'`` or ``'cuda'``.
        nms_score_thr (float): Passed to ``YOLOXHead`` as ``nms_score_thr``;
            defaults to the previously hard-coded 0.8.
    """

    # NOTE(review): confirm this signature (and any default for ``version``)
    # against existing callers.
    def __init__(self, version, device='cpu', nms_score_thr=0.8):
        super().__init__()
        assert 'cpu' in device or 'cuda' in device
        checkpoint = get_version(version)
        # Load on CPU first; the whole module is moved to ``device`` afterwards.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        pt = torch.load(checkpoint, 'cpu')
        self.pipeline = AutoRotateDetectorDataPipeline(**pt['pipeline_args'], device=device)
        self.backbone = CSPDarknet(**pt['backbone_args'])
        self.neck = YOLOXPAFPN(**pt['neck_args'])
        self.bbox_head = YOLOXHead(**pt['bbox_head_args'], nms_score_thr=nms_score_thr)
        self.load_state_dict(pt['state_dict'], strict=True)
        for param in self.parameters():
            param.requires_grad = False
        # nn.Module.to() moves parameters in place; eval() puts layers such as
        # BatchNorm/Dropout (if any) into inference behaviour.
        self.to(device)
        self.eval()
        print(f'Auto rotate detector load from version {version}.')

    def extract_feat(self, img):
        """Directly extract features from the backbone + neck."""
        return self.neck(self.backbone(img))

    @torch.no_grad()
    def forward(self, img):
        """Pick the candidate with the most detections (no test-time augmentation).

        Args:
            img (np.ndarray | str): Image of shape (H, W, C), or a path to an
                image; preprocessing is delegated to ``self.pipeline``.

        Returns:
            np.ndarray: The element of the pipeline's numpy outputs whose
            matching model input produced the most detected boxes; on ties
            the earliest candidate wins.

        Raises:
            ValueError: If the pipeline yields no candidates.
        """
        imgs, imgs_np = self.pipeline(img)
        if len(imgs) == 0:
            raise ValueError('AutoRotateDetector: pipeline produced no candidate images')
        num_classes = self.bbox_head.num_classes
        counts = []
        for candidate in imgs:
            feat = self.extract_feat(candidate)
            results_list = self.bbox_head.simple_test_bboxes(feat)
            # Count boxes: sum the lengths of the per-class arrays, not of the
            # per-result lists (those always have length num_classes).
            counts.append(sum(
                len(class_result)
                for det_bboxes, det_labels in results_list
                for class_result in bbox2result(det_bboxes, det_labels, num_classes)
            ))
        # max() returns the first maximal index, preserving first-wins ties.
        best = max(range(len(counts)), key=counts.__getitem__)
        return imgs_np[best]