# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
import copy
import warnings
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, is_norm
from mmengine.model import bias_init_with_prob, constant_init, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList)
from ..task_modules.samplers import PseudoSampler
from ..utils import filter_scores_and_topk, images_to_levels, multi_apply
from .base_dense_head import BaseDenseHead


@MODELS.register_module()
class YOLOV3Head(BaseDenseHead):
    """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767.

    Args:
        num_classes (int): The number of object classes (w/o background).
        in_channels (Sequence[int]): Number of input channels per scale.
        out_channels (Sequence[int]): The number of output channels per scale
            before the final 1x1 layer. Defaults to (1024, 512, 256).
        anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor
            generator.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder.
        featmap_strides (Sequence[int]): The stride of each scale.
            Should be in descending order. Defaults to (32, 16, 8).
        one_hot_smoother (float): Set a non-zero value to enable label-smooth.
            Defaults to 0.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config norm layer. Defaults to dict(type='BN', requires_grad=True).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='LeakyReLU', negative_slope=0.1).
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_conf (:obj:`ConfigDict` or dict): Config of confidence loss.
        loss_xy (:obj:`ConfigDict` or dict): Config of xy coordinate loss.
        loss_wh (:obj:`ConfigDict` or dict): Config of wh coordinate loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            YOLOV3 head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            YOLOV3 head. Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Sequence[int],
                 out_channels: Sequence[int] = (1024, 512, 256),
                 anchor_generator: ConfigType = dict(
                     type='YOLOAnchorGenerator',
                     base_sizes=[[(116, 90), (156, 198), (373, 326)],
                                 [(30, 61), (62, 45), (59, 119)],
                                 [(10, 13), (16, 30), (33, 23)]],
                     strides=[32, 16, 8]),
                 bbox_coder: ConfigType = dict(type='YOLOBBoxCoder'),
                 featmap_strides: Sequence[int] = (32, 16, 8),
                 one_hot_smoother: float = 0.,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: ConfigType = dict(
                     type='LeakyReLU', negative_slope=0.1),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_conf: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_xy: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_wh: ConfigType = dict(type='MSELoss', loss_weight=1.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=None)
        # Check params
        assert (len(in_channels) == len(out_channels) == len(featmap_strides))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            if train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'], context=self)
            else:
                self.sampler = PseudoSampler()

        self.one_hot_smoother = one_hot_smoother

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.prior_generator = TASK_UTILS.build(anchor_generator)

        self.loss_cls = MODELS.build(loss_cls)
        self.loss_conf = MODELS.build(loss_conf)
        self.loss_xy = MODELS.build(loss_xy)
        self.loss_wh = MODELS.build(loss_wh)

        self.num_base_priors = self.prior_generator.num_base_priors[0]
        assert len(
            self.prior_generator.num_base_priors) == len(featmap_strides)
        self._init_layers()

    @property
    def num_levels(self) -> int:
        """int: number of feature map levels"""
        return len(self.featmap_strides)

    @property
    def num_attrib(self) -> int:
        """int: number of attributes in pred_map, bboxes (4) +
        objectness (1) + num_classes"""
        return 5 + self.num_classes

    def _init_layers(self) -> None:
        """initialize conv layers in YOLOv3 head."""
        self.convs_bridge = nn.ModuleList()
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_bridge = ConvModule(
                self.in_channels[i],
                self.out_channels[i],
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            conv_pred = nn.Conv2d(self.out_channels[i],
                                  self.num_base_priors * self.num_attrib, 1)

            self.convs_bridge.append(conv_bridge)
            self.convs_pred.append(conv_pred)
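
    # Channel math for the prediction convs above (e.g. a COCO model with
    # 80 classes): num_attrib = 5 + 80 = 85 attributes per prior, and with
    # 3 base priors per location each ``conv_pred`` is a 1x1 conv emitting
    # 3 * 85 = 255 output channels.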

    def init_weights(self) -> None:
        """initialize weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

        # Use prior in model initialization to improve stability
        for conv_pred, stride in zip(self.convs_pred, self.featmap_strides):
            bias = conv_pred.bias.reshape(self.num_base_priors, -1)
            # init objectness with prior of 8 objects per feature map
            # refer to https://github.com/ultralytics/yolov3
            nn.init.constant_(bias.data[:, 4],
                              bias_init_with_prob(8 / (608 / stride)**2))
            nn.init.constant_(bias.data[:, 5:], bias_init_with_prob(0.01))
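
    # A worked instance of the objectness prior above (a sketch, assuming
    # the 608x608 input the comment refers to): at stride 32 the grid is
    # 19x19 = 361 cells, so the prior probability is 8 / 361 ~= 0.022;
    # ``bias_init_with_prob(p)`` returns -log((1 - p) / p), i.e. the logit
    # whose sigmoid equals p.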

    def forward(self, x: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple[Tensor]: A tuple of multi-level prediction maps, where
                each is a 4D-tensor of shape
                (batch_size, num_anchors * (5+num_classes), height, width).
        """
        assert len(x) == self.num_levels
        pred_maps = []
        for i in range(self.num_levels):
            feat = x[i]
            feat = self.convs_bridge[i](feat)
            pred_map = self.convs_pred[i](feat)
            pred_maps.append(pred_map)

        # Note the trailing comma: the head returns a 1-tuple whose single
        # element is the tuple of per-level prediction maps.
        return tuple(pred_maps),
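
    # Shape sketch for forward() (a sketch, assuming an 80-class model and
    # a 608x608 input, so strides 32/16/8 give 19/38/76 grids; C0..C2 are
    # whatever channel widths the neck provides):
    #   x:         [(N, C0, 19, 19), (N, C1, 38, 38), (N, C2, 76, 76)]
    #   pred_maps: ((N, 255, 19, 19), (N, 255, 38, 38), (N, 255, 76, 76))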

    def predict_by_feat(self,
                        pred_maps: Sequence[Tensor],
                        batch_img_metas: Optional[List[dict]],
                        cfg: OptConfigType = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results. It has been accelerated since PR #5991.

        Args:
            pred_maps (Sequence[Tensor]): Raw predictions for a batch of
                images.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (:obj:`ConfigDict` or dict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        assert len(pred_maps) == self.num_levels
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        num_imgs = len(batch_img_metas)
        featmap_sizes = [pred_map.shape[-2:] for pred_map in pred_maps]

        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=pred_maps[0].device)
        flatten_preds = []
        flatten_strides = []
        for pred, stride in zip(pred_maps, self.featmap_strides):
            pred = pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                    self.num_attrib)
            pred[..., :2].sigmoid_()
            flatten_preds.append(pred)
            flatten_strides.append(
                pred.new_tensor(stride).expand(pred.size(1)))

        flatten_preds = torch.cat(flatten_preds, dim=1)
        flatten_bbox_preds = flatten_preds[..., :4]
        flatten_objectness = flatten_preds[..., 4].sigmoid()
        flatten_cls_scores = flatten_preds[..., 5:].sigmoid()
        flatten_anchors = torch.cat(mlvl_anchors)
        flatten_strides = torch.cat(flatten_strides)
        flatten_bboxes = self.bbox_coder.decode(flatten_anchors,
                                                flatten_bbox_preds,
                                                flatten_strides.unsqueeze(-1))
        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            # Filtering out all predictions with conf < conf_thr
            conf_thr = cfg.get('conf_thr', -1)
            if conf_thr > 0:
                conf_inds = objectness >= conf_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            score_thr = cfg.get('score_thr', 0)
            nms_pre = cfg.get('nms_pre', -1)
            scores, labels, keep_idxs, _ = filter_scores_and_topk(
                scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores,
                labels=labels,
                bboxes=bboxes[keep_idxs],
                score_factors=objectness[keep_idxs],
            )
            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms,
                img_meta=img_meta)
            results_list.append(results)
        return results_list
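
    # Flattening sketch for predict_by_feat (a sketch, assuming 80 classes
    # and a 608x608 input): each level's (N, 255, H, W) map is reshaped to
    # (N, H*W*3, 85) and the levels are concatenated, giving
    # (361 + 1444 + 5776) * 3 = 22743 priors per image before
    # thresholding, top-k selection and NMS.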

    def loss_by_feat(
            self,
            pred_maps: Sequence[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            pred_maps (list[Tensor]): Prediction map for each scale level,
                shape (N, num_anchors * num_attrib, H, W)
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        num_imgs = len(batch_img_metas)
        device = pred_maps[0][0].device

        featmap_sizes = [
            pred_maps[i].shape[-2:] for i in range(self.num_levels)
        ]
        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)
        anchor_list = [mlvl_anchors for _ in range(num_imgs)]

        responsible_flag_list = []
        for img_id in range(num_imgs):
            responsible_flag_list.append(
                self.responsible_flags(featmap_sizes,
                                       batch_gt_instances[img_id].bboxes,
                                       device))

        target_maps_list, neg_maps_list = self.get_targets(
            anchor_list, responsible_flag_list, batch_gt_instances)

        losses_cls, losses_conf, losses_xy, losses_wh = multi_apply(
            self.loss_by_feat_single, pred_maps, target_maps_list,
            neg_maps_list)

        return dict(
            loss_cls=losses_cls,
            loss_conf=losses_conf,
            loss_xy=losses_xy,
            loss_wh=losses_wh)
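
    # Note on the returned dict: ``multi_apply`` maps
    # ``loss_by_feat_single`` over the scale levels and transposes the
    # results, so each value above is a list with one loss tensor per
    # level rather than a single scalar.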

    def loss_by_feat_single(self, pred_map: Tensor, target_map: Tensor,
                            neg_map: Tensor) -> tuple:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            pred_map (Tensor): Raw predictions for a single level.
            target_map (Tensor): The Ground-Truth target for a single level.
            neg_map (Tensor): The negative masks for a single level.

        Returns:
            tuple:
                loss_cls (Tensor): Classification loss.
                loss_conf (Tensor): Confidence loss.
                loss_xy (Tensor): Regression loss of x, y coordinate.
                loss_wh (Tensor): Regression loss of w, h coordinate.
        """
        num_imgs = len(pred_map)
        pred_map = pred_map.permute(0, 2, 3,
                                    1).reshape(num_imgs, -1, self.num_attrib)
        neg_mask = neg_map.float()
        pos_mask = target_map[..., 4]
        pos_and_neg_mask = neg_mask + pos_mask
        pos_mask = pos_mask.unsqueeze(dim=-1)
        if torch.max(pos_and_neg_mask) > 1.:
            warnings.warn('There is overlap between pos and neg sample.')
            pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.)

        pred_xy = pred_map[..., :2]
        pred_wh = pred_map[..., 2:4]
        pred_conf = pred_map[..., 4]
        pred_label = pred_map[..., 5:]

        target_xy = target_map[..., :2]
        target_wh = target_map[..., 2:4]
        target_conf = target_map[..., 4]
        target_label = target_map[..., 5:]

        loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask)
        loss_conf = self.loss_conf(
            pred_conf, target_conf, weight=pos_and_neg_mask)
        loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask)
        loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask)
        return loss_cls, loss_conf, loss_xy, loss_wh
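
    # Mask semantics in the loss above: ``pos_mask`` comes from the
    # objectness channel of ``target_map`` (1 at anchors matched to a GT),
    # so the xy/wh/cls losses are only supervised at positives, while the
    # confidence loss also covers the negatives flagged in ``neg_map``.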

    def get_targets(self, anchor_list: List[List[Tensor]],
                    responsible_flag_list: List[List[Tensor]],
                    batch_gt_instances: List[InstanceData]) -> tuple:
        """Compute target maps for anchors in multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_total_anchors, 4).
            responsible_flag_list (list[list[Tensor]]): Multi level
                responsible flags of each image. Each element is a tensor of
                shape (num_total_anchors, )
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - target_map_list (list[Tensor]): Target map of each level.
            - neg_map_list (list[Tensor]): Negative map of each level.
        """
        num_imgs = len(anchor_list)

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        results = multi_apply(self._get_targets_single, anchor_list,
                              responsible_flag_list, batch_gt_instances)

        all_target_maps, all_neg_maps = results
        assert num_imgs == len(all_target_maps) == len(all_neg_maps)
        target_maps_list = images_to_levels(all_target_maps,
                                            num_level_anchors)
        neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors)

        return target_maps_list, neg_maps_list
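
    # ``images_to_levels`` above stacks the per-image concatenated targets
    # and splits them back by ``num_level_anchors``, yielding one
    # (num_imgs, num_level_anchors_i, ...) tensor per scale level.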

    def _get_targets_single(self, anchors: List[Tensor],
                            responsible_flags: List[Tensor],
                            gt_instances: InstanceData) -> tuple:
        """Generate matching bounding box prior and converted GT.

        Args:
            anchors (List[Tensor]): Multi-level anchors of the image.
            responsible_flags (List[Tensor]): Multi-level responsible flags
                of anchors
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple:
                target_map (Tensor): Prediction target map of each
                    scale level, shape (num_total_anchors,
                    5+num_classes)
                neg_map (Tensor): Negative map of each scale level,
                    shape (num_total_anchors,)
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        anchor_strides = []
        for i in range(len(anchors)):
            anchor_strides.append(
                torch.tensor(self.featmap_strides[i],
                             device=gt_bboxes.device).repeat(len(anchors[i])))
        concat_anchors = torch.cat(anchors)
        concat_responsible_flags = torch.cat(responsible_flags)

        anchor_strides = torch.cat(anchor_strides)
        assert len(anchor_strides) == len(concat_anchors) == \
               len(concat_responsible_flags)
        pred_instances = InstanceData(
            priors=concat_anchors, responsible_flags=concat_responsible_flags)

        assign_result = self.assigner.assign(pred_instances, gt_instances)
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        target_map = concat_anchors.new_zeros(
            concat_anchors.size(0), self.num_attrib)

        target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode(
            sampling_result.pos_priors, sampling_result.pos_gt_bboxes,
            anchor_strides[sampling_result.pos_inds])

        target_map[sampling_result.pos_inds, 4] = 1

        gt_labels_one_hot = F.one_hot(
            gt_labels, num_classes=self.num_classes).float()
        if self.one_hot_smoother != 0:  # label smooth
            gt_labels_one_hot = gt_labels_one_hot * (
                1 - self.one_hot_smoother
            ) + self.one_hot_smoother / self.num_classes
        target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[
            sampling_result.pos_assigned_gt_inds]

        neg_map = concat_anchors.new_zeros(
            concat_anchors.size(0), dtype=torch.uint8)
        neg_map[sampling_result.neg_inds] = 1

        return target_map, neg_map
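
    # Per-anchor target layout produced above:
    #   target_map[:, :4]  -> encoded (tx, ty, tw, th) regression targets
    #   target_map[:, 4]   -> objectness (1 for positive anchors)
    #   target_map[:, 5:]  -> (optionally smoothed) one-hot class labels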

    def responsible_flags(self, featmap_sizes: List[tuple], gt_bboxes: Tensor,
                          device: str) -> List[Tensor]:
        """Generate responsible anchor flags of grid cells in multiple scales.

        Args:
            featmap_sizes (List[tuple]): List of feature map sizes in
                multiple feature levels.
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
            device (str): Device where the anchors will be put on.

        Return:
            List[Tensor]: responsible flags of anchors in multiple level
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_responsible_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.prior_generator.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            gt_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
            gt_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
            gt_grid_x = torch.floor(gt_cx / anchor_stride[0]).long()
            gt_grid_y = torch.floor(gt_cy / anchor_stride[1]).long()
            # row major indexing
            gt_bboxes_grid_idx = gt_grid_y * feat_w + gt_grid_x

            responsible_grid = torch.zeros(
                feat_h * feat_w, dtype=torch.uint8, device=device)
            responsible_grid[gt_bboxes_grid_idx] = 1

            responsible_grid = responsible_grid[:, None].expand(
                responsible_grid.size(0),
                self.prior_generator.num_base_priors[i]).contiguous().view(-1)

            multi_level_responsible_flags.append(responsible_grid)
        return multi_level_responsible_flags
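

# A minimal standalone sketch (an illustration, not part of the original
# module) of the responsible-flag indexing in ``responsible_flags``: each
# GT center is mapped to one grid cell, then the flag is repeated for
# every base prior of that cell. The grid size, stride and boxes below
# are made-up values.
if __name__ == '__main__':
    feat_h, feat_w, stride, num_base_priors = 8, 8, 32, 3
    gt_bboxes = torch.tensor([[10., 10., 50., 50.],      # center (30, 30)
                              [100., 64., 160., 128.]])  # center (130, 96)
    gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5
    gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5
    # row-major cell index, as in responsible_flags()
    grid_idx = (torch.floor(gt_cy / stride) * feat_w +
                torch.floor(gt_cx / stride)).long()
    flags = torch.zeros(feat_h * feat_w, dtype=torch.uint8)
    flags[grid_idx] = 1
    # repeat each cell's flag for every base prior, as the method does
    flags = flags[:, None].expand(-1, num_base_priors).reshape(-1)
    print(grid_idx.tolist(), int(flags.sum()))  # [0, 28] -> 2 cells, 6 flags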