# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, OptMultiConfig)
from ..layers import fast_nms
from ..utils import images_to_levels, multi_apply, select_single_mlvl
from ..utils.misc import empty_instances
from .anchor_head import AnchorHead
from .base_mask_head import BaseMaskHead


@MODELS.register_module()
class YOLACTHead(AnchorHead):
    """YOLACT box head used in https://arxiv.org/abs/1904.02689.

    Note that YOLACT head is a light version of RetinaNet head.
    Four differences are described as follows:

    1. YOLACT box head has three-times fewer anchors.
    2. YOLACT box head shares the convs for box and cls branches.
    3. YOLACT box head uses OHEM instead of Focal loss.
    4. YOLACT box head predicts a set of mask coefficients for each box.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (:obj:`ConfigDict` or dict): Config dict for
            anchor generator.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        num_head_convs (int): Number of the conv layers shared by
            box and cls branches.
        num_protos (int): Number of the mask coefficients.
        use_ohem (bool): If true, ``loss_single_OHEM`` will be used for
            cls loss calculation. If false, ``loss_single`` will be used.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to
            construct and config conv layer.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to
            construct and config norm layer.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 anchor_generator: ConfigType = dict(
                     type='AnchorGenerator',
                     octave_base_scale=3,
                     scales_per_octave=1,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     reduction='none',
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
                 num_head_convs: int = 1,
                 num_protos: int = 32,
                 use_ohem: bool = True,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier',
                     distribution='uniform',
                     bias=0,
                     layer='Conv2d'),
                 **kwargs) -> None:
        self.num_head_convs = num_head_convs
        self.num_protos = num_protos
        self.use_ohem = use_ohem
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.head_convs = ModuleList()
        for i in range(self.num_head_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.head_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.conv_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            3,
            padding=1)
        self.conv_reg = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, 3, padding=1)
        self.conv_coeff = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.num_protos,
            3,
            padding=1)
    def forward_single(self, x: Tensor) -> tuple:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:

            - cls_score (Tensor): Cls scores for a single scale level,
              the channels number is num_anchors * num_classes.
            - bbox_pred (Tensor): Box energies / deltas for a single scale
              level, the channels number is num_anchors * 4.
            - coeff_pred (Tensor): Mask coefficients for a single scale
              level, the channels number is num_anchors * num_protos.
        """
        for head_conv in self.head_convs:
            x = head_conv(x)
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        coeff_pred = self.conv_coeff(x).tanh()
        return cls_score, bbox_pred, coeff_pred
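
    # A minimal shape sketch for ``forward_single`` (illustrative only; the
    # batch size and the 69x69 P3 level below are assumptions based on the
    # default 550x550 YOLACT config, not values fixed by this file):
    # >>> head = YOLACTHead(num_classes=80, in_channels=256)
    # >>> p3 = torch.rand(2, 256, 69, 69)  # one FPN level, stride 8
    # >>> cls_score, bbox_pred, coeff_pred = head.forward_single(p3)
    # >>> cls_score.shape   # (2, 3 * 81, 69, 69): 3 priors, 80 cls + 1 bg
    # >>> bbox_pred.shape   # (2, 3 * 4, 69, 69)
    # >>> coeff_pred.shape  # (2, 3 * 32, 69, 69), tanh-squashed to [-1, 1]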
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            coeff_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the bbox
        head.

        When ``self.use_ohem == True``, it functions like ``SSDHead.loss``,
        otherwise, it follows ``AnchorHead.loss``.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            coeff_preds (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            unmap_outputs=not self.use_ohem,
            return_sampling_results=True)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor, sampling_results) = cls_reg_targets

        if self.use_ohem:
            num_images = len(batch_img_metas)
            all_cls_scores = torch.cat([
                s.permute(0, 2, 3, 1).reshape(
                    num_images, -1, self.cls_out_channels) for s in cls_scores
            ], 1)
            all_labels = torch.cat(labels_list, -1).view(num_images, -1)
            all_label_weights = torch.cat(label_weights_list,
                                          -1).view(num_images, -1)
            all_bbox_preds = torch.cat([
                b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
                for b in bbox_preds
            ], -2)
            all_bbox_targets = torch.cat(bbox_targets_list,
                                         -2).view(num_images, -1, 4)
            all_bbox_weights = torch.cat(bbox_weights_list,
                                         -2).view(num_images, -1, 4)

            # concat all level anchors to a single tensor
            all_anchors = []
            for i in range(num_images):
                all_anchors.append(torch.cat(anchor_list[i]))

            # check NaN and Inf
            assert torch.isfinite(all_cls_scores).all().item(), \
                'classification scores become infinite or NaN!'
            assert torch.isfinite(all_bbox_preds).all().item(), \
                'bbox predictions become infinite or NaN!'

            losses_cls, losses_bbox = multi_apply(
                self.OHEMloss_by_feat_single,
                all_cls_scores,
                all_bbox_preds,
                all_anchors,
                all_labels,
                all_label_weights,
                all_bbox_targets,
                all_bbox_weights,
                avg_factor=avg_factor)
        else:
            # anchor number of multi levels
            num_level_anchors = [
                anchors.size(0) for anchors in anchor_list[0]
            ]
            # concat all level anchors and flags to a single tensor
            concat_anchor_list = []
            for i in range(len(anchor_list)):
                concat_anchor_list.append(torch.cat(anchor_list[i]))
            all_anchor_list = images_to_levels(concat_anchor_list,
                                               num_level_anchors)
            losses_cls, losses_bbox = multi_apply(
                self.loss_by_feat_single,
                cls_scores,
                bbox_preds,
                all_anchor_list,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
                avg_factor=avg_factor)
        losses = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
        # update `_raw_positive_infos`, which will be used when calling
        # `get_positive_infos`.
        self._raw_positive_infos.update(coeff_preds=coeff_preds)
        return losses
    def OHEMloss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                                anchors: Tensor, labels: Tensor,
                                label_weights: Tensor, bbox_targets: Tensor,
                                bbox_weights: Tensor,
                                avg_factor: int) -> tuple:
        """Compute loss of a single image. Similar to
        :func:`SSDHead.loss_by_feat_single`.

        Args:
            cls_score (Tensor): Box scores for each image with shape
                (num_total_anchors, num_classes).
            bbox_pred (Tensor): Box energies / deltas for each image
                with shape (num_total_anchors, 4).
            anchors (Tensor): Box reference for each scale level with shape
                (num_total_anchors, 4).
            labels (Tensor): Labels of each anchor with shape
                (num_total_anchors,).
            label_weights (Tensor): Label weights of each anchor with shape
                (num_total_anchors,).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each anchor
                with shape (num_total_anchors, 4).
            avg_factor (int): Average factor that is used to average
                the loss. When using sampling method, avg_factor is usually
                the sum of positive and negative priors. When using
                `PseudoSampler`, `avg_factor` is usually equal to the number
                of positive priors.

        Returns:
            Tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one
            feature map.
        """
        loss_cls_all = self.loss_cls(cls_score, labels, label_weights)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
            as_tuple=False).reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero(
            as_tuple=False).view(-1)

        num_pos_samples = pos_inds.size(0)
        if num_pos_samples == 0:
            num_neg_samples = neg_inds.size(0)
        else:
            num_neg_samples = self.train_cfg['neg_pos_ratio'] * \
                num_pos_samples
            if num_neg_samples > neg_inds.size(0):
                num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
        return loss_cls[None], loss_bbox
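
    # A hedged, standalone sketch of the hard-negative mining above, with a
    # hypothetical ``neg_pos_ratio`` of 3 (the usual SSD/YOLACT setting):
    # >>> loss_cls_all = torch.rand(100)    # per-anchor classification losses
    # >>> pos_inds = torch.tensor([4, 17])  # 2 positive anchors
    # >>> neg_inds = torch.arange(20, 100)  # 80 negative anchors
    # >>> num_neg = min(3 * pos_inds.numel(), neg_inds.numel())  # -> 6
    # >>> topk_neg, _ = loss_cls_all[neg_inds].topk(num_neg)
    # Only the six hardest negatives (largest losses) contribute, which keeps
    # the negative:positive ratio bounded instead of relying on focal loss.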
    def get_positive_infos(self) -> InstanceList:
        """Get positive information from sampling results.

        Returns:
            list[:obj:`InstanceData`]: Positive information of each image,
            usually including positive bboxes, positive labels, positive
            priors, positive coeffs, etc.
        """
        assert len(self._raw_positive_infos) > 0
        sampling_results = self._raw_positive_infos['sampling_results']
        num_imgs = len(sampling_results)

        coeff_pred_list = []
        for coeff_pred_per_level in self._raw_positive_infos['coeff_preds']:
            coeff_pred_per_level = \
                coeff_pred_per_level.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos)
            coeff_pred_list.append(coeff_pred_per_level)
        coeff_preds = torch.cat(coeff_pred_list, dim=1)

        pos_info_list = []
        for idx, sampling_result in enumerate(sampling_results):
            pos_info = InstanceData()
            coeff_preds_single = coeff_preds[idx]
            pos_info.pos_assigned_gt_inds = \
                sampling_result.pos_assigned_gt_inds
            pos_info.pos_inds = sampling_result.pos_inds
            pos_info.coeffs = coeff_preds_single[sampling_result.pos_inds]
            pos_info.bboxes = sampling_result.pos_gt_bboxes
            pos_info_list.append(pos_info)
        return pos_info_list
    def predict_by_feat(self,
                        cls_scores,
                        bbox_preds,
                        coeff_preds,
                        batch_img_metas,
                        cfg=None,
                        rescale=True,
                        **kwargs):
        """Similar to :func:`AnchorHead.get_bboxes`, but additionally
        processes coeff_preds.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            coeff_preds (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W).
            batch_img_metas (list[dict]): Batch image meta info.
            cfg (:obj:`Config` | None): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - coeffs (Tensor): the predicted mask coefficients of
              instance inside the corresponding box has a shape
              (n, num_protos).
        """
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)

        result_list = []
        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            coeff_pred_list = select_single_mlvl(coeff_preds, img_id)
            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                coeff_preds_list=coeff_pred_list,
                mlvl_priors=mlvl_priors,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale)
            result_list.append(results)
        return result_list
    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                coeff_preds_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results. Similar to :func:`AnchorHead._predict_by_feat_single`,
        but additionally processes coeff_preds_list and uses fast NMS instead
        of traditional NMS.

        Args:
            cls_score_list (list[Tensor]): Box scores for a single scale
                level with shape (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas for a single
                scale level with shape (num_priors * 4, H, W).
            coeff_preds_list (list[Tensor]): Mask coefficients for a single
                scale level with shape (num_priors * num_protos, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid,
                has shape (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - coeffs (Tensor): the predicted mask coefficients of
              instance inside the corresponding box has a shape
              (n, num_protos).
        """
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_priors)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_coeffs = []
        for cls_score, bbox_pred, coeff_pred, priors in \
                zip(cls_score_list, bbox_pred_list,
                    coeff_preds_list, mlvl_priors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(
                1, 2, 0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            coeff_pred = coeff_pred.permute(
                1, 2, 0).reshape(-1, self.num_protos)

            if 0 < nms_pre < scores.shape[0]:
                # Get maximum scores for foreground classes.
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                priors = priors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                coeff_pred = coeff_pred[topk_inds, :]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_coeffs.append(coeff_pred)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = torch.cat(mlvl_valid_priors)
        multi_bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=img_shape)

        multi_scores = torch.cat(mlvl_scores)
        multi_coeffs = torch.cat(mlvl_coeffs)

        return self._bbox_post_process(
            multi_bboxes=multi_bboxes,
            multi_scores=multi_scores,
            multi_coeffs=multi_coeffs,
            cfg=cfg,
            rescale=rescale,
            img_meta=img_meta)
    def _bbox_post_process(self,
                           multi_bboxes: Tensor,
                           multi_scores: Tensor,
                           multi_coeffs: Tensor,
                           cfg: ConfigType,
                           rescale: bool = False,
                           img_meta: Optional[dict] = None,
                           **kwargs) -> InstanceData:
        """bbox post-processing method.

        The boxes would be rescaled to the original image scale and the nms
        operation applied. Usually ``with_nms`` is False when used for
        aug test.

        Args:
            multi_bboxes (Tensor): Predicted bbox that concat all levels.
            multi_scores (Tensor): Bbox scores that concat all levels.
            multi_coeffs (Tensor): Mask coefficients that concat all levels.
            cfg (ConfigDict): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            img_meta (dict, optional): Image meta info. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - coeffs (Tensor): the predicted mask coefficients of
              instance inside the corresponding box has a shape
              (n, num_protos).
        """
        if rescale:
            assert img_meta.get('scale_factor') is not None
            multi_bboxes /= multi_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet
            # v2.0, BG cat_id: num_class
            padding = multi_scores.new_zeros(multi_scores.shape[0], 1)
            multi_scores = torch.cat([multi_scores, padding], dim=1)
        det_bboxes, det_labels, det_coeffs = fast_nms(
            multi_bboxes, multi_scores, multi_coeffs, cfg.score_thr,
            cfg.iou_thr, cfg.top_k, cfg.max_per_img)
        results = InstanceData()
        results.bboxes = det_bboxes[:, :4]
        results.scores = det_bboxes[:, -1]
        results.labels = det_labels
        results.coeffs = det_coeffs
        return results
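

# Rough test-time data flow for the box head (a sketch under the default
# config; the FPN sizes are assumptions, and the full wiring to
# ``YOLACTProtonet`` lives in the YOLACT detector class, not in this file):
# >>> feats = tuple(torch.rand(1, 256, s, s) for s in (69, 35, 18, 9, 5))
# >>> box_head = YOLACTHead(num_classes=80, in_channels=256)
# >>> cls_scores, bbox_preds, coeff_preds = box_head(feats)
# ``predict_by_feat`` then decodes boxes, runs ``fast_nms``, and keeps the
# per-detection ``coeffs`` that ``YOLACTProtonet`` combines with prototypes.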


@MODELS.register_module()
class YOLACTProtonet(BaseMaskHead):
    """YOLACT mask head used in https://arxiv.org/abs/1904.02689.

    This head outputs the mask prototypes for YOLACT.

    Args:
        in_channels (int): Number of channels in the input feature map.
        proto_channels (tuple[int]): Output channels of protonet convs.
        proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
        include_last_relu (bool): Whether to keep the last ReLU of protonet.
        num_protos (int): Number of prototypes.
        num_classes (int): Number of categories excluding the background
            category.
        loss_mask_weight (float): Reweight the mask loss by this factor.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
        with_seg_branch (bool): Whether to apply a semantic segmentation
            branch and calculate loss during training to increase
            performance with no speed penalty. Defaults to True.
        loss_segm (:obj:`ConfigDict` or dict, optional): Config of
            semantic segmentation loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config
            of head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            head.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
    """
    def __init__(
        self,
        num_classes: int,
        in_channels: int = 256,
        proto_channels: tuple = (256, 256, 256, None, 256, 32),
        proto_kernel_sizes: tuple = (3, 3, 3, -2, 3, 1),
        include_last_relu: bool = True,
        num_protos: int = 32,
        loss_mask_weight: float = 1.0,
        max_masks_to_train: int = 100,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        with_seg_branch: bool = True,
        loss_segm: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        init_cfg=dict(
            type='Xavier',
            distribution='uniform',
            override=dict(name='protonet'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.proto_channels = proto_channels
        self.proto_kernel_sizes = proto_kernel_sizes
        self.include_last_relu = include_last_relu

        # Segmentation branch
        self.with_seg_branch = with_seg_branch
        self.segm_branch = SegmentationModule(
            num_classes=num_classes, in_channels=in_channels) \
            if with_seg_branch else None
        self.loss_segm = MODELS.build(loss_segm) if with_seg_branch else None

        self.loss_mask_weight = loss_mask_weight
        self.num_protos = num_protos
        self.num_classes = num_classes
        self.max_masks_to_train = max_masks_to_train
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()
    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        # Possible patterns:
        # ( 256, 3) -> conv
        # ( 256,-2) -> deconv
        # (None,-2) -> bilinear interpolate
        in_channels = self.in_channels
        protonets = ModuleList()
        for num_channels, kernel_size in zip(self.proto_channels,
                                             self.proto_kernel_sizes):
            if kernel_size > 0:
                layer = nn.Conv2d(
                    in_channels,
                    num_channels,
                    kernel_size,
                    padding=kernel_size // 2)
            else:
                if num_channels is None:
                    layer = InterpolateModule(
                        scale_factor=-kernel_size,
                        mode='bilinear',
                        align_corners=False)
                else:
                    layer = nn.ConvTranspose2d(
                        in_channels,
                        num_channels,
                        -kernel_size,
                        padding=kernel_size // 2)
            protonets.append(layer)
            protonets.append(nn.ReLU(inplace=True))
            in_channels = num_channels if num_channels is not None \
                else in_channels
        if not self.include_last_relu:
            protonets = protonets[:-1]
        self.protonet = nn.Sequential(*protonets)
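
    # With the default ``proto_channels=(256, 256, 256, None, 256, 32)`` and
    # ``proto_kernel_sizes=(3, 3, 3, -2, 3, 1)``, the loop above builds the
    # following stack (every entry is followed by an in-place ReLU; the
    # trailing one is dropped when ``include_last_relu=False``):
    #   Conv2d(256 -> 256, k=3) x 3        # ( 256,  3) -> conv
    #   InterpolateModule(scale_factor=2)  # (None, -2) -> bilinear upsample
    #   Conv2d(256 -> 256, k=3)
    #   Conv2d(256 -> 32, k=1)             # 32 == num_protos prototype maps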
    def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
        """Forward feature from the upstream network to get prototypes and
        linearly combine the prototypes, using mask coefficients, into
        instance masks. Finally, crop the instance masks with given bboxes.

        Args:
            x (Tuple[Tensor]): Feature from the upstream network, which is
                a 4D-tensor.
            positive_infos (List[:obj:`InstanceData`]): Positive information
                computed by the detection head.

        Returns:
            tuple: Predicted instance segmentation masks and
            semantic segmentation map.
        """
        # YOLACT uses a single feature map to get segmentation masks
        single_x = x[0]

        # YOLACT segmentation branch, if not training or segmentation branch
        # is None, will not process the forward function.
        if self.segm_branch is not None and self.training:
            segm_preds = self.segm_branch(single_x)
        else:
            segm_preds = None

        # YOLACT mask head
        prototypes = self.protonet(single_x)
        prototypes = prototypes.permute(0, 2, 3, 1).contiguous()

        num_imgs = single_x.size(0)

        mask_pred_list = []
        for idx in range(num_imgs):
            cur_prototypes = prototypes[idx]
            pos_coeffs = positive_infos[idx].coeffs

            # Linearly combine the prototypes with the mask coefficients
            mask_preds = cur_prototypes @ pos_coeffs.t()
            mask_preds = torch.sigmoid(mask_preds)
            mask_pred_list.append(mask_preds)
        return mask_pred_list, segm_preds
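
    # The per-image combination above is a plain matrix product. A minimal
    # sketch with hypothetical sizes (138x138 prototypes, 32 protos, 5
    # positive boxes):
    # >>> protos = torch.rand(138, 138, 32)  # (H, W, num_protos)
    # >>> coeffs = torch.rand(5, 32)         # (num_pos, num_protos)
    # >>> masks = torch.sigmoid(protos @ coeffs.t())  # (138, 138, 5)
    # Each instance mask is a coefficient-weighted sum of the shared
    # prototype maps, squashed to (0, 1).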
    def loss_by_feat(self, mask_preds: List[Tensor], segm_preds: Tensor,
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     positive_infos: InstanceList, **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (list[Tensor]): List of predicted instance masks,
                each has shape (H, W, num_pos).
            segm_preds (Tensor): Predicted semantic segmentation map with
                shape (N, num_classes, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
            positive_infos (List[:obj:`InstanceData`]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `YOLACTProtonet`'
        losses = dict()

        # crop
        croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas,
                                                positive_infos)

        loss_mask = []
        loss_segm = []
        num_imgs, _, mask_h, mask_w = segm_preds.size()
        assert num_imgs == len(croped_mask_pred)
        segm_avg_factor = num_imgs * mask_h * mask_w
        total_pos = 0

        if self.segm_branch is not None:
            assert segm_preds is not None

        for idx in range(num_imgs):
            img_meta = batch_img_metas[idx]

            (mask_preds, pos_mask_targets, segm_targets, num_pos,
             gt_bboxes_for_reweight) = self._get_targets_single(
                 croped_mask_pred[idx], segm_preds[idx],
                 batch_gt_instances[idx], positive_infos[idx])

            # segmentation loss
            if self.with_seg_branch:
                if segm_targets is None:
                    loss = segm_preds[idx].sum() * 0.
                else:
                    loss = self.loss_segm(
                        segm_preds[idx],
                        segm_targets,
                        avg_factor=segm_avg_factor)
                loss_segm.append(loss)
            # mask loss
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                loss = mask_preds.sum() * 0.
            else:
                mask_preds = torch.clamp(mask_preds, 0, 1)
                loss = F.binary_cross_entropy(
                    mask_preds, pos_mask_targets,
                    reduction='none') * self.loss_mask_weight

                h, w = img_meta['img_shape'][:2]
                gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
                                   gt_bboxes_for_reweight[:, 0]) / w
                gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
                                    gt_bboxes_for_reweight[:, 1]) / h
                loss = loss.mean(
                    dim=(1, 2)) / gt_bboxes_width / gt_bboxes_height
                loss = torch.sum(loss)
            loss_mask.append(loss)

        if total_pos == 0:
            total_pos += 1  # avoid nan
        loss_mask = [x / total_pos for x in loss_mask]

        losses.update(loss_mask=loss_mask)
        if self.with_seg_branch:
            losses.update(loss_segm=loss_segm)

        return losses
    def _get_targets_single(self, mask_preds: Tensor, segm_pred: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData):
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Cropped instance masks with shape
                (num_pos, H, W).
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (num_classes, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. It usually
                contains following keys.

                - pos_assigned_gt_inds (Tensor): Assigner GT indexes of
                  positive proposals, has shape (num_pos, ).
                - pos_inds (Tensor): Positive index of image, has
                  shape (num_pos, ).
                - coeffs (Tensor): Positive mask coefficients
                  with shape (num_pos, num_protos).
                - bboxes (Tensor): Positive bboxes with shape
                  (num_pos, 4).

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - segm_targets (Tensor): Semantic segmentation targets with shape
              (num_classes, segm_h, segm_w).
            - num_pos (int): Positive numbers.
            - gt_bboxes_for_reweight (Tensor): GT bboxes that match to the
              positive priors has shape (num_pos, 4).
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        device = gt_bboxes.device
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()
        if gt_masks.size(0) == 0:
            return mask_preds, None, None, 0, None

        # process with semantic segmentation targets
        if segm_pred is not None:
            num_classes, segm_h, segm_w = segm_pred.size()
            with torch.no_grad():
                downsampled_masks = F.interpolate(
                    gt_masks.unsqueeze(0), (segm_h, segm_w),
                    mode='bilinear',
                    align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.gt(0.5).float()
                segm_targets = torch.zeros_like(
                    segm_pred, requires_grad=False)
                for obj_idx in range(downsampled_masks.size(0)):
                    segm_targets[gt_labels[obj_idx] - 1] = torch.max(
                        segm_targets[gt_labels[obj_idx] - 1],
                        downsampled_masks[obj_idx])
        else:
            segm_targets = None

        # process with mask targets
        pos_assigned_gt_inds = positive_info.pos_assigned_gt_inds
        num_pos = pos_assigned_gt_inds.size(0)
        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if num_pos > self.max_masks_to_train:
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train

        gt_bboxes_for_reweight = gt_bboxes[pos_assigned_gt_inds]

        mask_h, mask_w = mask_preds.shape[-2:]
        gt_masks = F.interpolate(
            gt_masks.unsqueeze(0), (mask_h, mask_w),
            mode='bilinear',
            align_corners=False).squeeze(0)
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]

        return (mask_preds, pos_mask_targets, segm_targets, num_pos,
                gt_bboxes_for_reweight)
    def crop_mask_preds(self, mask_preds: List[Tensor],
                        batch_img_metas: List[dict],
                        positive_infos: InstanceList) -> list:
        """Crop predicted masks by zeroing out everything not in the predicted
        bbox.

        Args:
            mask_preds (list[Tensor]): Predicted instance masks with shape
                (H, W, num_pos).
            batch_img_metas (list[dict]): Meta information of multiple images.
            positive_infos (List[:obj:`InstanceData`]): Positive
                information computed by the detection head.

        Returns:
            list: The cropped masks.
        """
        croped_mask_preds = []
        for img_meta, mask_preds, cur_info in zip(batch_img_metas, mask_preds,
                                                  positive_infos):
            bboxes_for_cropping = copy.deepcopy(cur_info.bboxes)
            h, w = img_meta['img_shape'][:2]
            bboxes_for_cropping[:, 0::2] /= w
            bboxes_for_cropping[:, 1::2] /= h
            mask_preds = self.crop_single(mask_preds, bboxes_for_cropping)
            mask_preds = mask_preds.permute(2, 0, 1).contiguous()
            croped_mask_preds.append(mask_preds)
        return croped_mask_preds
    def crop_single(self,
                    masks: Tensor,
                    boxes: Tensor,
                    padding: int = 1) -> Tensor:
        """Crop single predicted masks by zeroing out everything not in the
        predicted bbox.

        Args:
            masks (Tensor): Predicted prototypes, has shape [H, W, N].
            boxes (Tensor): Bbox coords in relative point form with
                shape [N, 4].
            padding (int): Image padding size.

        Returns:
            Tensor: The cropped masks.
        """
        h, w, n = masks.size()
        x1, x2 = self.sanitize_coordinates(
            boxes[:, 0], boxes[:, 2], w, padding, cast=False)
        y1, y2 = self.sanitize_coordinates(
            boxes[:, 1], boxes[:, 3], h, padding, cast=False)

        rows = torch.arange(
            w, device=masks.device, dtype=x1.dtype).view(1, -1,
                                                         1).expand(h, w, n)
        cols = torch.arange(
            h, device=masks.device, dtype=x1.dtype).view(-1, 1,
                                                         1).expand(h, w, n)

        masks_left = rows >= x1.view(1, 1, -1)
        masks_right = rows < x2.view(1, 1, -1)
        masks_up = cols >= y1.view(1, 1, -1)
        masks_down = cols < y2.view(1, 1, -1)

        crop_mask = masks_left * masks_right * masks_up * masks_down
        return masks * crop_mask.float()
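
    # Broadcast sketch for the crop above (hypothetical toy sizes): with
    # masks of shape (H, W, N), ``rows``/``cols`` expand to (H, W, N) and
    # compare against per-box bounds of shape (1, 1, N):
    # >>> rows = torch.arange(4).view(1, -1, 1).expand(4, 4, 2)
    # >>> x1 = torch.tensor([1., 0.])
    # >>> (rows >= x1.view(1, 1, -1)).shape  # (4, 4, 2), one plane per box
    # The four boolean planes multiply into an axis-aligned box indicator
    # that zeroes out every pixel outside each predicted box.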
    def sanitize_coordinates(self,
                             x1: Tensor,
                             x2: Tensor,
                             img_size: int,
                             padding: int = 0,
                             cast: bool = True) -> tuple:
        """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
        and x2 <= image_size. Also converts from relative to absolute
        coordinates and casts the results to long tensors.

        Warning: this does things in-place behind the scenes so
        copy if necessary.

        Args:
            x1 (Tensor): shape (N, ).
            x2 (Tensor): shape (N, ).
            img_size (int): Size of the input image.
            padding (int): x1 >= padding, x2 <= image_size-padding.
            cast (bool): If cast is false, the result won't be cast to longs.

        Returns:
            tuple:

            - x1 (Tensor): Sanitized _x1.
            - x2 (Tensor): Sanitized _x2.
        """
        x1 = x1 * img_size
        x2 = x2 * img_size
        if cast:
            x1 = x1.long()
            x2 = x2.long()
        # Take min/max from the original pair so that reversed coordinates
        # (x1 > x2) are actually swapped instead of both collapsing onto x2,
        # which is what the earlier sequential min-then-max did.
        x1, x2 = torch.min(x1, x2), torch.max(x1, x2)
        x1 = torch.clamp(x1 - padding, min=0)
        x2 = torch.clamp(x2 + padding, max=img_size)
        return x1, x2
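
    # Doctest-style sketch (values are illustrative, assuming the swap fix
    # above, on a hypothetical ``proto_head`` instance): reversed relative
    # coordinates come back ordered, absolute, and padded by one pixel:
    # >>> proto_head.sanitize_coordinates(
    # ...     torch.tensor([0.8]), torch.tensor([0.2]), img_size=100,
    # ...     padding=1, cast=True)
    # (tensor([19]), tensor([81]))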
    def predict_by_feat(self,
                        mask_preds: List[Tensor],
                        segm_preds: Tensor,
                        results_list: InstanceList,
                        batch_img_metas: List[dict],
                        rescale: bool = True,
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mask_preds (list[Tensor]): Predicted instance masks with shape
                (H, W, num_pos).
            segm_preds (Tensor): Predicted semantic segmentation map with
                shape (N, num_classes, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains
            following keys.

            - scores (Tensor): Classification scores, has shape
              (num_instances,).
            - labels (Tensor): Has shape (num_instances,).
            - masks (Tensor): Processed mask results, has
              shape (num_instances, h, w).
        """
        assert len(mask_preds) == len(results_list) == len(batch_img_metas)

        croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas,
                                                results_list)

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = results_list[img_id]
            bboxes = results.bboxes
            mask_preds = croped_mask_pred[img_id]
            if bboxes.shape[0] == 0 or mask_preds.shape[0] == 0:
                results_list[img_id] = empty_instances(
                    [img_meta],
                    bboxes.device,
                    task_type='mask',
                    instance_results=[results])[0]
            else:
                im_mask = self._predict_by_feat_single(
                    mask_preds=croped_mask_pred[img_id],
                    bboxes=bboxes,
                    img_meta=img_meta,
                    rescale=rescale)
                results.masks = im_mask
        return results_list
    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                img_meta: dict,
                                rescale: bool,
                                cfg: OptConfigType = None):
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            mask_preds (Tensor): Cropped instance masks with shape
                (num_pos, H, W).
            bboxes (Tensor): Bbox coords in relative point form with
                shape [N, 4].
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If rescale is False, then returned masks will
                fit the scale of imgs[0].
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            Tensor: Instance segmentation masks with shape
            (num_instances, img_h, img_w), binarized with ``cfg.mask_thr``.
        """
        cfg = self.test_cfg if cfg is None else cfg
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        if rescale:  # in-placed rescale the bboxes
            bboxes /= scale_factor
        else:
            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
            img_w = np.round(img_w * w_scale.item()).astype(np.int32)

        masks = F.interpolate(
            mask_preds.unsqueeze(0), (img_h, img_w),
            mode='bilinear',
            align_corners=False).squeeze(0) > cfg.mask_thr

        if cfg.mask_thr_binary < 0:
            # for visualization and debugging
            masks = (masks * 255).to(dtype=torch.uint8)

        return masks


class SegmentationModule(BaseModule):
    """YOLACT segmentation branch used in https://arxiv.org/abs/1904.02689.

    In mmdet v2.x ``segm_loss`` is calculated in ``YOLACTSegmHead``, while in
    mmdet v3.x ``SegmentationModule`` is used to obtain the predicted semantic
    segmentation map and ``segm_loss`` is calculated in ``YOLACTProtonet``.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int = 256,
        init_cfg: ConfigType = dict(
            type='Xavier',
            distribution='uniform',
            override=dict(name='segm_conv'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.segm_conv = nn.Conv2d(
            self.in_channels, self.num_classes, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        """Forward feature from the upstream network.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.

        Returns:
            Tensor: Predicted semantic segmentation map with shape
            (N, num_classes, H, W).
        """
        return self.segm_conv(x)


class InterpolateModule(BaseModule):
    """This is a module version of F.interpolate.

    Any arguments you give it just get passed along for the ride.
    """

    def __init__(self, *args, init_cfg=None, **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)
        self.args = args
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.

        Returns:
            Tensor: A 4D-tensor feature map.
        """
        return F.interpolate(x, *self.args, **self.kwargs)
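

# Minimal usage sketch of ``InterpolateModule`` (arguments are passed straight
# through to ``F.interpolate``, mirroring how ``YOLACTProtonet._init_layers``
# builds the protonet's upsampling stage):
# >>> up = InterpolateModule(scale_factor=2, mode='bilinear',
# ...                        align_corners=False)
# >>> up(torch.rand(1, 256, 69, 69)).shape
# torch.Size([1, 256, 138, 138])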