grid_head.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Tuple
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from mmcv.cnn import ConvModule
  8. from mmengine.config import ConfigDict
  9. from mmengine.model import BaseModule
  10. from mmengine.structures import InstanceData
  11. from torch import Tensor
  12. from mmdet.models.task_modules.samplers import SamplingResult
  13. from mmdet.registry import MODELS
  14. from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
  15. @MODELS.register_module()
  16. class GridHead(BaseModule):
  17. """Implementation of `Grid Head <https://arxiv.org/abs/1811.12030>`_
  18. Args:
  19. grid_points (int): The number of grid points. Defaults to 9.
  20. num_convs (int): The number of convolution layers. Defaults to 8.
  21. roi_feat_size (int): RoI feature size. Default to 14.
  22. in_channels (int): The channel number of inputs features.
  23. Defaults to 256.
  24. conv_kernel_size (int): The kernel size of convolution layers.
  25. Defaults to 3.
  26. point_feat_channels (int): The number of channels of each point
  27. features. Defaults to 64.
  28. class_agnostic (bool): Whether use class agnostic classification.
  29. If so, the output channels of logits will be 1. Defaults to False.
  30. loss_grid (:obj:`ConfigDict` or dict): Config of grid loss.
  31. conv_cfg (:obj:`ConfigDict` or dict, optional) dictionary to
  32. construct and config conv layer.
  33. norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
  34. config norm layer.
  35. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  36. dict]): Initialization config dict.
  37. """
  38. def __init__(
  39. self,
  40. grid_points: int = 9,
  41. num_convs: int = 8,
  42. roi_feat_size: int = 14,
  43. in_channels: int = 256,
  44. conv_kernel_size: int = 3,
  45. point_feat_channels: int = 64,
  46. deconv_kernel_size: int = 4,
  47. class_agnostic: bool = False,
  48. loss_grid: ConfigType = dict(
  49. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15),
  50. conv_cfg: OptConfigType = None,
  51. norm_cfg: ConfigType = dict(type='GN', num_groups=36),
  52. init_cfg: MultiConfig = [
  53. dict(type='Kaiming', layer=['Conv2d', 'Linear']),
  54. dict(
  55. type='Normal',
  56. layer='ConvTranspose2d',
  57. std=0.001,
  58. override=dict(
  59. type='Normal',
  60. name='deconv2',
  61. std=0.001,
  62. bias=-np.log(0.99 / 0.01)))
  63. ]
  64. ) -> None:
  65. super().__init__(init_cfg=init_cfg)
  66. self.grid_points = grid_points
  67. self.num_convs = num_convs
  68. self.roi_feat_size = roi_feat_size
  69. self.in_channels = in_channels
  70. self.conv_kernel_size = conv_kernel_size
  71. self.point_feat_channels = point_feat_channels
  72. self.conv_out_channels = self.point_feat_channels * self.grid_points
  73. self.class_agnostic = class_agnostic
  74. self.conv_cfg = conv_cfg
  75. self.norm_cfg = norm_cfg
  76. if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
  77. assert self.conv_out_channels % norm_cfg['num_groups'] == 0
  78. assert self.grid_points >= 4
  79. self.grid_size = int(np.sqrt(self.grid_points))
  80. if self.grid_size * self.grid_size != self.grid_points:
  81. raise ValueError('grid_points must be a square number')
  82. # the predicted heatmap is half of whole_map_size
  83. if not isinstance(self.roi_feat_size, int):
  84. raise ValueError('Only square RoIs are supporeted in Grid R-CNN')
  85. self.whole_map_size = self.roi_feat_size * 4
  86. # compute point-wise sub-regions
  87. self.sub_regions = self.calc_sub_regions()
  88. self.convs = []
  89. for i in range(self.num_convs):
  90. in_channels = (
  91. self.in_channels if i == 0 else self.conv_out_channels)
  92. stride = 2 if i == 0 else 1
  93. padding = (self.conv_kernel_size - 1) // 2
  94. self.convs.append(
  95. ConvModule(
  96. in_channels,
  97. self.conv_out_channels,
  98. self.conv_kernel_size,
  99. stride=stride,
  100. padding=padding,
  101. conv_cfg=self.conv_cfg,
  102. norm_cfg=self.norm_cfg,
  103. bias=True))
  104. self.convs = nn.Sequential(*self.convs)
  105. self.deconv1 = nn.ConvTranspose2d(
  106. self.conv_out_channels,
  107. self.conv_out_channels,
  108. kernel_size=deconv_kernel_size,
  109. stride=2,
  110. padding=(deconv_kernel_size - 2) // 2,
  111. groups=grid_points)
  112. self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
  113. self.deconv2 = nn.ConvTranspose2d(
  114. self.conv_out_channels,
  115. grid_points,
  116. kernel_size=deconv_kernel_size,
  117. stride=2,
  118. padding=(deconv_kernel_size - 2) // 2,
  119. groups=grid_points)
  120. # find the 4-neighbor of each grid point
  121. self.neighbor_points = []
  122. grid_size = self.grid_size
  123. for i in range(grid_size): # i-th column
  124. for j in range(grid_size): # j-th row
  125. neighbors = []
  126. if i > 0: # left: (i - 1, j)
  127. neighbors.append((i - 1) * grid_size + j)
  128. if j > 0: # up: (i, j - 1)
  129. neighbors.append(i * grid_size + j - 1)
  130. if j < grid_size - 1: # down: (i, j + 1)
  131. neighbors.append(i * grid_size + j + 1)
  132. if i < grid_size - 1: # right: (i + 1, j)
  133. neighbors.append((i + 1) * grid_size + j)
  134. self.neighbor_points.append(tuple(neighbors))
  135. # total edges in the grid
  136. self.num_edges = sum([len(p) for p in self.neighbor_points])
  137. self.forder_trans = nn.ModuleList() # first-order feature transition
  138. self.sorder_trans = nn.ModuleList() # second-order feature transition
  139. for neighbors in self.neighbor_points:
  140. fo_trans = nn.ModuleList()
  141. so_trans = nn.ModuleList()
  142. for _ in range(len(neighbors)):
  143. # each transition module consists of a 5x5 depth-wise conv and
  144. # 1x1 conv.
  145. fo_trans.append(
  146. nn.Sequential(
  147. nn.Conv2d(
  148. self.point_feat_channels,
  149. self.point_feat_channels,
  150. 5,
  151. stride=1,
  152. padding=2,
  153. groups=self.point_feat_channels),
  154. nn.Conv2d(self.point_feat_channels,
  155. self.point_feat_channels, 1)))
  156. so_trans.append(
  157. nn.Sequential(
  158. nn.Conv2d(
  159. self.point_feat_channels,
  160. self.point_feat_channels,
  161. 5,
  162. 1,
  163. 2,
  164. groups=self.point_feat_channels),
  165. nn.Conv2d(self.point_feat_channels,
  166. self.point_feat_channels, 1)))
  167. self.forder_trans.append(fo_trans)
  168. self.sorder_trans.append(so_trans)
  169. self.loss_grid = MODELS.build(loss_grid)
  170. def forward(self, x: Tensor) -> Dict[str, Tensor]:
  171. """forward function of ``GridHead``.
  172. Args:
  173. x (Tensor): RoI features, has shape
  174. (num_rois, num_channels, roi_feat_size, roi_feat_size).
  175. Returns:
  176. Dict[str, Tensor]: Return a dict including fused and unfused
  177. heatmap.
  178. """
  179. assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
  180. # RoI feature transformation, downsample 2x
  181. x = self.convs(x)
  182. c = self.point_feat_channels
  183. # first-order fusion
  184. x_fo = [None for _ in range(self.grid_points)]
  185. for i, points in enumerate(self.neighbor_points):
  186. x_fo[i] = x[:, i * c:(i + 1) * c]
  187. for j, point_idx in enumerate(points):
  188. x_fo[i] = x_fo[i] + self.forder_trans[i][j](
  189. x[:, point_idx * c:(point_idx + 1) * c])
  190. # second-order fusion
  191. x_so = [None for _ in range(self.grid_points)]
  192. for i, points in enumerate(self.neighbor_points):
  193. x_so[i] = x[:, i * c:(i + 1) * c]
  194. for j, point_idx in enumerate(points):
  195. x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])
  196. # predicted heatmap with fused features
  197. x2 = torch.cat(x_so, dim=1)
  198. x2 = self.deconv1(x2)
  199. x2 = F.relu(self.norm1(x2), inplace=True)
  200. heatmap = self.deconv2(x2)
  201. # predicted heatmap with original features (applicable during training)
  202. if self.training:
  203. x1 = x
  204. x1 = self.deconv1(x1)
  205. x1 = F.relu(self.norm1(x1), inplace=True)
  206. heatmap_unfused = self.deconv2(x1)
  207. else:
  208. heatmap_unfused = heatmap
  209. return dict(fused=heatmap, unfused=heatmap_unfused)
  210. def calc_sub_regions(self) -> List[Tuple[float]]:
  211. """Compute point specific representation regions.
  212. See `Grid R-CNN Plus <https://arxiv.org/abs/1906.05688>`_ for details.
  213. """
  214. # to make it consistent with the original implementation, half_size
  215. # is computed as 2 * quarter_size, which is smaller
  216. half_size = self.whole_map_size // 4 * 2
  217. sub_regions = []
  218. for i in range(self.grid_points):
  219. x_idx = i // self.grid_size
  220. y_idx = i % self.grid_size
  221. if x_idx == 0:
  222. sub_x1 = 0
  223. elif x_idx == self.grid_size - 1:
  224. sub_x1 = half_size
  225. else:
  226. ratio = x_idx / (self.grid_size - 1) - 0.25
  227. sub_x1 = max(int(ratio * self.whole_map_size), 0)
  228. if y_idx == 0:
  229. sub_y1 = 0
  230. elif y_idx == self.grid_size - 1:
  231. sub_y1 = half_size
  232. else:
  233. ratio = y_idx / (self.grid_size - 1) - 0.25
  234. sub_y1 = max(int(ratio * self.whole_map_size), 0)
  235. sub_regions.append(
  236. (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
  237. return sub_regions
  238. def get_targets(self, sampling_results: List[SamplingResult],
  239. rcnn_train_cfg: ConfigDict) -> Tensor:
  240. """Calculate the ground truth for all samples in a batch according to
  241. the sampling_results.".
  242. Args:
  243. sampling_results (List[:obj:`SamplingResult`]): Assign results of
  244. all images in a batch after sampling.
  245. rcnn_train_cfg (:obj:`ConfigDict`): `train_cfg` of RCNN.
  246. Returns:
  247. Tensor: Grid heatmap targets.
  248. """
  249. # mix all samples (across images) together.
  250. pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
  251. dim=0).cpu()
  252. pos_gt_bboxes = torch.cat(
  253. [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
  254. assert pos_bboxes.shape == pos_gt_bboxes.shape
  255. # expand pos_bboxes to 2x of original size
  256. x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
  257. y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
  258. x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
  259. y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
  260. pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
  261. pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
  262. pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)
  263. num_rois = pos_bboxes.shape[0]
  264. map_size = self.whole_map_size
  265. # this is not the final target shape
  266. targets = torch.zeros((num_rois, self.grid_points, map_size, map_size),
  267. dtype=torch.float)
  268. # pre-compute interpolation factors for all grid points.
  269. # the first item is the factor of x-dim, and the second is y-dim.
  270. # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
  271. factors = []
  272. for j in range(self.grid_points):
  273. x_idx = j // self.grid_size
  274. y_idx = j % self.grid_size
  275. factors.append((1 - x_idx / (self.grid_size - 1),
  276. 1 - y_idx / (self.grid_size - 1)))
  277. radius = rcnn_train_cfg.pos_radius
  278. radius2 = radius**2
  279. for i in range(num_rois):
  280. # ignore small bboxes
  281. if (pos_bbox_ws[i] <= self.grid_size
  282. or pos_bbox_hs[i] <= self.grid_size):
  283. continue
  284. # for each grid point, mark a small circle as positive
  285. for j in range(self.grid_points):
  286. factor_x, factor_y = factors[j]
  287. gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
  288. 1 - factor_x) * pos_gt_bboxes[i, 2]
  289. gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
  290. 1 - factor_y) * pos_gt_bboxes[i, 3]
  291. cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
  292. map_size)
  293. cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
  294. map_size)
  295. for x in range(cx - radius, cx + radius + 1):
  296. for y in range(cy - radius, cy + radius + 1):
  297. if x >= 0 and x < map_size and y >= 0 and y < map_size:
  298. if (x - cx)**2 + (y - cy)**2 <= radius2:
  299. targets[i, j, y, x] = 1
  300. # reduce the target heatmap size by a half
  301. # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
  302. sub_targets = []
  303. for i in range(self.grid_points):
  304. sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
  305. sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
  306. sub_targets = torch.cat(sub_targets, dim=1)
  307. sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
  308. return sub_targets
  309. def loss(self, grid_pred: Tensor, sample_idx: Tensor,
  310. sampling_results: List[SamplingResult],
  311. rcnn_train_cfg: ConfigDict) -> dict:
  312. """Calculate the loss based on the features extracted by the grid head.
  313. Args:
  314. grid_pred (dict[str, Tensor]): Outputs of grid_head forward.
  315. sample_idx (Tensor): The sampling index of ``grid_pred``.
  316. sampling_results (List[obj:SamplingResult]): Assign results of
  317. all images in a batch after sampling.
  318. rcnn_train_cfg (obj:`ConfigDict`): `train_cfg` of RCNN.
  319. Returns:
  320. dict: A dictionary of loss and targets components.
  321. """
  322. grid_targets = self.get_targets(sampling_results, rcnn_train_cfg)
  323. grid_targets = grid_targets[sample_idx]
  324. loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
  325. loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
  326. loss_grid = loss_fused + loss_unfused
  327. return dict(loss_grid=loss_grid)
  328. def predict_by_feat(self,
  329. grid_preds: Dict[str, Tensor],
  330. results_list: List[InstanceData],
  331. batch_img_metas: List[dict],
  332. rescale: bool = False) -> InstanceList:
  333. """Adjust the predicted bboxes from bbox head.
  334. Args:
  335. grid_preds (dict[str, Tensor]): dictionary outputted by forward
  336. function.
  337. results_list (list[:obj:`InstanceData`]): Detection results of
  338. each image.
  339. batch_img_metas (list[dict]): List of image information.
  340. rescale (bool): If True, return boxes in original image space.
  341. Defaults to False.
  342. Returns:
  343. list[:obj:`InstanceData`]: Detection results of each image
  344. after the post process. Each item usually contains following keys.
  345. - scores (Tensor): Classification scores, has a shape \
  346. (num_instance, )
  347. - labels (Tensor): Labels of bboxes, has a shape (num_instances, ).
  348. - bboxes (Tensor): Has a shape (num_instances, 4), the last \
  349. dimension 4 arrange as (x1, y1, x2, y2).
  350. """
  351. num_roi_per_img = tuple(res.bboxes.size(0) for res in results_list)
  352. grid_preds = {
  353. k: v.split(num_roi_per_img, 0)
  354. for k, v in grid_preds.items()
  355. }
  356. for i, results in enumerate(results_list):
  357. if len(results) != 0:
  358. bboxes = self._predict_by_feat_single(
  359. grid_pred=grid_preds['fused'][i],
  360. bboxes=results.bboxes,
  361. img_meta=batch_img_metas[i],
  362. rescale=rescale)
  363. results.bboxes = bboxes
  364. return results_list
  365. def _predict_by_feat_single(self,
  366. grid_pred: Tensor,
  367. bboxes: Tensor,
  368. img_meta: dict,
  369. rescale: bool = False) -> Tensor:
  370. """Adjust ``bboxes`` according to ``grid_pred``.
  371. Args:
  372. grid_pred (Tensor): Grid fused heatmap.
  373. bboxes (Tensor): Predicted bboxes, has shape (n, 4)
  374. img_meta (dict): image information.
  375. rescale (bool): If True, return boxes in original image space.
  376. Defaults to False.
  377. Returns:
  378. Tensor: adjusted bboxes.
  379. """
  380. assert bboxes.size(0) == grid_pred.size(0)
  381. grid_pred = grid_pred.sigmoid()
  382. R, c, h, w = grid_pred.shape
  383. half_size = self.whole_map_size // 4 * 2
  384. assert h == w == half_size
  385. assert c == self.grid_points
  386. # find the point with max scores in the half-sized heatmap
  387. grid_pred = grid_pred.view(R * c, h * w)
  388. pred_scores, pred_position = grid_pred.max(dim=1)
  389. xs = pred_position % w
  390. ys = pred_position // w
  391. # get the position in the whole heatmap instead of half-sized heatmap
  392. for i in range(self.grid_points):
  393. xs[i::self.grid_points] += self.sub_regions[i][0]
  394. ys[i::self.grid_points] += self.sub_regions[i][1]
  395. # reshape to (num_rois, grid_points)
  396. pred_scores, xs, ys = tuple(
  397. map(lambda x: x.view(R, c), [pred_scores, xs, ys]))
  398. # get expanded pos_bboxes
  399. widths = (bboxes[:, 2] - bboxes[:, 0]).unsqueeze(-1)
  400. heights = (bboxes[:, 3] - bboxes[:, 1]).unsqueeze(-1)
  401. x1 = (bboxes[:, 0, None] - widths / 2)
  402. y1 = (bboxes[:, 1, None] - heights / 2)
  403. # map the grid point to the absolute coordinates
  404. abs_xs = (xs.float() + 0.5) / w * widths + x1
  405. abs_ys = (ys.float() + 0.5) / h * heights + y1
  406. # get the grid points indices that fall on the bbox boundaries
  407. x1_inds = [i for i in range(self.grid_size)]
  408. y1_inds = [i * self.grid_size for i in range(self.grid_size)]
  409. x2_inds = [
  410. self.grid_points - self.grid_size + i
  411. for i in range(self.grid_size)
  412. ]
  413. y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]
  414. # voting of all grid points on some boundary
  415. bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
  416. dim=1, keepdim=True) / (
  417. pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
  418. bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
  419. dim=1, keepdim=True) / (
  420. pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
  421. bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
  422. dim=1, keepdim=True) / (
  423. pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
  424. bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
  425. dim=1, keepdim=True) / (
  426. pred_scores[:, y2_inds].sum(dim=1, keepdim=True))
  427. bboxes = torch.cat([bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2], dim=1)
  428. bboxes[:, [0, 2]].clamp_(min=0, max=img_meta['img_shape'][1])
  429. bboxes[:, [1, 3]].clamp_(min=0, max=img_meta['img_shape'][0])
  430. if rescale:
  431. assert img_meta.get('scale_factor') is not None
  432. bboxes /= bboxes.new_tensor(img_meta['scale_factor']).repeat(
  433. (1, 2))
  434. return bboxes