cascade_roi_head.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Sequence, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from mmengine.model import ModuleList
  6. from mmengine.structures import InstanceData
  7. from torch import Tensor
  8. from mmdet.models.task_modules.samplers import SamplingResult
  9. from mmdet.models.test_time_augs import merge_aug_masks
  10. from mmdet.registry import MODELS, TASK_UTILS
  11. from mmdet.structures import SampleList
  12. from mmdet.structures.bbox import bbox2roi, get_box_tensor
  13. from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
  14. OptMultiConfig)
  15. from ..utils.misc import empty_instances, unpack_gt_instances
  16. from .base_roi_head import BaseRoIHead
  17. @MODELS.register_module()
  18. class CascadeRoIHead(BaseRoIHead):
  19. """Cascade roi head including one bbox head and one mask head.
  20. https://arxiv.org/abs/1712.00726
  21. """
  22. def __init__(self,
  23. num_stages: int,
  24. stage_loss_weights: Union[List[float], Tuple[float]],
  25. bbox_roi_extractor: OptMultiConfig = None,
  26. bbox_head: OptMultiConfig = None,
  27. mask_roi_extractor: OptMultiConfig = None,
  28. mask_head: OptMultiConfig = None,
  29. shared_head: OptConfigType = None,
  30. train_cfg: OptConfigType = None,
  31. test_cfg: OptConfigType = None,
  32. init_cfg: OptMultiConfig = None) -> None:
  33. assert bbox_roi_extractor is not None
  34. assert bbox_head is not None
  35. assert shared_head is None, \
  36. 'Shared head is not supported in Cascade RCNN anymore'
  37. self.num_stages = num_stages
  38. self.stage_loss_weights = stage_loss_weights
  39. super().__init__(
  40. bbox_roi_extractor=bbox_roi_extractor,
  41. bbox_head=bbox_head,
  42. mask_roi_extractor=mask_roi_extractor,
  43. mask_head=mask_head,
  44. shared_head=shared_head,
  45. train_cfg=train_cfg,
  46. test_cfg=test_cfg,
  47. init_cfg=init_cfg)
  48. def init_bbox_head(self, bbox_roi_extractor: MultiConfig,
  49. bbox_head: MultiConfig) -> None:
  50. """Initialize box head and box roi extractor.
  51. Args:
  52. bbox_roi_extractor (:obj:`ConfigDict`, dict or list):
  53. Config of box roi extractor.
  54. bbox_head (:obj:`ConfigDict`, dict or list): Config
  55. of box in box head.
  56. """
  57. self.bbox_roi_extractor = ModuleList()
  58. self.bbox_head = ModuleList()
  59. if not isinstance(bbox_roi_extractor, list):
  60. bbox_roi_extractor = [
  61. bbox_roi_extractor for _ in range(self.num_stages)
  62. ]
  63. if not isinstance(bbox_head, list):
  64. bbox_head = [bbox_head for _ in range(self.num_stages)]
  65. assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
  66. for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
  67. self.bbox_roi_extractor.append(MODELS.build(roi_extractor))
  68. self.bbox_head.append(MODELS.build(head))
  69. def init_mask_head(self, mask_roi_extractor: MultiConfig,
  70. mask_head: MultiConfig) -> None:
  71. """Initialize mask head and mask roi extractor.
  72. Args:
  73. mask_head (dict): Config of mask in mask head.
  74. mask_roi_extractor (:obj:`ConfigDict`, dict or list):
  75. Config of mask roi extractor.
  76. """
  77. self.mask_head = nn.ModuleList()
  78. if not isinstance(mask_head, list):
  79. mask_head = [mask_head for _ in range(self.num_stages)]
  80. assert len(mask_head) == self.num_stages
  81. for head in mask_head:
  82. self.mask_head.append(MODELS.build(head))
  83. if mask_roi_extractor is not None:
  84. self.share_roi_extractor = False
  85. self.mask_roi_extractor = ModuleList()
  86. if not isinstance(mask_roi_extractor, list):
  87. mask_roi_extractor = [
  88. mask_roi_extractor for _ in range(self.num_stages)
  89. ]
  90. assert len(mask_roi_extractor) == self.num_stages
  91. for roi_extractor in mask_roi_extractor:
  92. self.mask_roi_extractor.append(MODELS.build(roi_extractor))
  93. else:
  94. self.share_roi_extractor = True
  95. self.mask_roi_extractor = self.bbox_roi_extractor
  96. def init_assigner_sampler(self) -> None:
  97. """Initialize assigner and sampler for each stage."""
  98. self.bbox_assigner = []
  99. self.bbox_sampler = []
  100. if self.train_cfg is not None:
  101. for idx, rcnn_train_cfg in enumerate(self.train_cfg):
  102. self.bbox_assigner.append(
  103. TASK_UTILS.build(rcnn_train_cfg.assigner))
  104. self.current_stage = idx
  105. self.bbox_sampler.append(
  106. TASK_UTILS.build(
  107. rcnn_train_cfg.sampler,
  108. default_args=dict(context=self)))
  109. def _bbox_forward(self, stage: int, x: Tuple[Tensor],
  110. rois: Tensor) -> dict:
  111. """Box head forward function used in both training and testing.
  112. Args:
  113. stage (int): The current stage in Cascade RoI Head.
  114. x (tuple[Tensor]): List of multi-level img features.
  115. rois (Tensor): RoIs with the shape (n, 5) where the first
  116. column indicates batch id of each RoI.
  117. Returns:
  118. dict[str, Tensor]: Usually returns a dictionary with keys:
  119. - `cls_score` (Tensor): Classification scores.
  120. - `bbox_pred` (Tensor): Box energies / deltas.
  121. - `bbox_feats` (Tensor): Extract bbox RoI features.
  122. """
  123. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  124. bbox_head = self.bbox_head[stage]
  125. bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
  126. rois)
  127. # do not support caffe_c4 model anymore
  128. cls_score, bbox_pred = bbox_head(bbox_feats)
  129. bbox_results = dict(
  130. cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
  131. return bbox_results
  132. def bbox_loss(self, stage: int, x: Tuple[Tensor],
  133. sampling_results: List[SamplingResult]) -> dict:
  134. """Run forward function and calculate loss for box head in training.
  135. Args:
  136. stage (int): The current stage in Cascade RoI Head.
  137. x (tuple[Tensor]): List of multi-level img features.
  138. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  139. Returns:
  140. dict: Usually returns a dictionary with keys:
  141. - `cls_score` (Tensor): Classification scores.
  142. - `bbox_pred` (Tensor): Box energies / deltas.
  143. - `bbox_feats` (Tensor): Extract bbox RoI features.
  144. - `loss_bbox` (dict): A dictionary of bbox loss components.
  145. - `rois` (Tensor): RoIs with the shape (n, 5) where the first
  146. column indicates batch id of each RoI.
  147. - `bbox_targets` (tuple): Ground truth for proposals in a
  148. single image. Containing the following list of Tensors:
  149. (labels, label_weights, bbox_targets, bbox_weights)
  150. """
  151. bbox_head = self.bbox_head[stage]
  152. rois = bbox2roi([res.priors for res in sampling_results])
  153. bbox_results = self._bbox_forward(stage, x, rois)
  154. bbox_results.update(rois=rois)
  155. bbox_loss_and_target = bbox_head.loss_and_target(
  156. cls_score=bbox_results['cls_score'],
  157. bbox_pred=bbox_results['bbox_pred'],
  158. rois=rois,
  159. sampling_results=sampling_results,
  160. rcnn_train_cfg=self.train_cfg[stage])
  161. bbox_results.update(bbox_loss_and_target)
  162. return bbox_results
  163. def _mask_forward(self, stage: int, x: Tuple[Tensor],
  164. rois: Tensor) -> dict:
  165. """Mask head forward function used in both training and testing.
  166. Args:
  167. stage (int): The current stage in Cascade RoI Head.
  168. x (tuple[Tensor]): Tuple of multi-level img features.
  169. rois (Tensor): RoIs with the shape (n, 5) where the first
  170. column indicates batch id of each RoI.
  171. Returns:
  172. dict: Usually returns a dictionary with keys:
  173. - `mask_preds` (Tensor): Mask prediction.
  174. """
  175. mask_roi_extractor = self.mask_roi_extractor[stage]
  176. mask_head = self.mask_head[stage]
  177. mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
  178. rois)
  179. # do not support caffe_c4 model anymore
  180. mask_preds = mask_head(mask_feats)
  181. mask_results = dict(mask_preds=mask_preds)
  182. return mask_results
  183. def mask_loss(self, stage: int, x: Tuple[Tensor],
  184. sampling_results: List[SamplingResult],
  185. batch_gt_instances: InstanceList) -> dict:
  186. """Run forward function and calculate loss for mask head in training.
  187. Args:
  188. stage (int): The current stage in Cascade RoI Head.
  189. x (tuple[Tensor]): Tuple of multi-level img features.
  190. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  191. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  192. gt_instance. It usually includes ``bboxes``, ``labels``, and
  193. ``masks`` attributes.
  194. Returns:
  195. dict: Usually returns a dictionary with keys:
  196. - `mask_preds` (Tensor): Mask prediction.
  197. - `loss_mask` (dict): A dictionary of mask loss components.
  198. """
  199. pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
  200. mask_results = self._mask_forward(stage, x, pos_rois)
  201. mask_head = self.mask_head[stage]
  202. mask_loss_and_target = mask_head.loss_and_target(
  203. mask_preds=mask_results['mask_preds'],
  204. sampling_results=sampling_results,
  205. batch_gt_instances=batch_gt_instances,
  206. rcnn_train_cfg=self.train_cfg[stage])
  207. mask_results.update(mask_loss_and_target)
  208. return mask_results
  209. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  210. batch_data_samples: SampleList) -> dict:
  211. """Perform forward propagation and loss calculation of the detection
  212. roi on the features of the upstream network.
  213. Args:
  214. x (tuple[Tensor]): List of multi-level img features.
  215. rpn_results_list (list[:obj:`InstanceData`]): List of region
  216. proposals.
  217. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  218. data samples. It usually includes information such
  219. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  220. Returns:
  221. dict[str, Tensor]: A dictionary of loss components
  222. """
  223. # TODO: May add a new function in baseroihead
  224. assert len(rpn_results_list) == len(batch_data_samples)
  225. outputs = unpack_gt_instances(batch_data_samples)
  226. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  227. = outputs
  228. num_imgs = len(batch_data_samples)
  229. losses = dict()
  230. results_list = rpn_results_list
  231. for stage in range(self.num_stages):
  232. self.current_stage = stage
  233. stage_loss_weight = self.stage_loss_weights[stage]
  234. # assign gts and sample proposals
  235. sampling_results = []
  236. if self.with_bbox or self.with_mask:
  237. bbox_assigner = self.bbox_assigner[stage]
  238. bbox_sampler = self.bbox_sampler[stage]
  239. for i in range(num_imgs):
  240. results = results_list[i]
  241. # rename rpn_results.bboxes to rpn_results.priors
  242. results.priors = results.pop('bboxes')
  243. assign_result = bbox_assigner.assign(
  244. results, batch_gt_instances[i],
  245. batch_gt_instances_ignore[i])
  246. sampling_result = bbox_sampler.sample(
  247. assign_result,
  248. results,
  249. batch_gt_instances[i],
  250. feats=[lvl_feat[i][None] for lvl_feat in x])
  251. sampling_results.append(sampling_result)
  252. # bbox head forward and loss
  253. bbox_results = self.bbox_loss(stage, x, sampling_results)
  254. for name, value in bbox_results['loss_bbox'].items():
  255. losses[f's{stage}.{name}'] = (
  256. value * stage_loss_weight if 'loss' in name else value)
  257. # mask head forward and loss
  258. if self.with_mask:
  259. mask_results = self.mask_loss(stage, x, sampling_results,
  260. batch_gt_instances)
  261. for name, value in mask_results['loss_mask'].items():
  262. losses[f's{stage}.{name}'] = (
  263. value * stage_loss_weight if 'loss' in name else value)
  264. # refine bboxes
  265. if stage < self.num_stages - 1:
  266. bbox_head = self.bbox_head[stage]
  267. with torch.no_grad():
  268. results_list = bbox_head.refine_bboxes(
  269. sampling_results, bbox_results, batch_img_metas)
  270. # Empty proposal
  271. if results_list is None:
  272. break
  273. return losses
  274. def predict_bbox(self,
  275. x: Tuple[Tensor],
  276. batch_img_metas: List[dict],
  277. rpn_results_list: InstanceList,
  278. rcnn_test_cfg: ConfigType,
  279. rescale: bool = False,
  280. **kwargs) -> InstanceList:
  281. """Perform forward propagation of the bbox head and predict detection
  282. results on the features of the upstream network.
  283. Args:
  284. x (tuple[Tensor]): Feature maps of all scale level.
  285. batch_img_metas (list[dict]): List of image information.
  286. rpn_results_list (list[:obj:`InstanceData`]): List of region
  287. proposals.
  288. rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
  289. rescale (bool): If True, return boxes in original image space.
  290. Defaults to False.
  291. Returns:
  292. list[:obj:`InstanceData`]: Detection results of each image
  293. after the post process.
  294. Each item usually contains following keys.
  295. - scores (Tensor): Classification scores, has a shape
  296. (num_instance, )
  297. - labels (Tensor): Labels of bboxes, has a shape
  298. (num_instances, ).
  299. - bboxes (Tensor): Has a shape (num_instances, 4),
  300. the last dimension 4 arrange as (x1, y1, x2, y2).
  301. """
  302. proposals = [res.bboxes for res in rpn_results_list]
  303. num_proposals_per_img = tuple(len(p) for p in proposals)
  304. rois = bbox2roi(proposals)
  305. if rois.shape[0] == 0:
  306. return empty_instances(
  307. batch_img_metas,
  308. rois.device,
  309. task_type='bbox',
  310. box_type=self.bbox_head[-1].predict_box_type,
  311. num_classes=self.bbox_head[-1].num_classes,
  312. score_per_cls=rcnn_test_cfg is None)
  313. rois, cls_scores, bbox_preds = self._refine_roi(
  314. x=x,
  315. rois=rois,
  316. batch_img_metas=batch_img_metas,
  317. num_proposals_per_img=num_proposals_per_img,
  318. **kwargs)
  319. results_list = self.bbox_head[-1].predict_by_feat(
  320. rois=rois,
  321. cls_scores=cls_scores,
  322. bbox_preds=bbox_preds,
  323. batch_img_metas=batch_img_metas,
  324. rescale=rescale,
  325. rcnn_test_cfg=rcnn_test_cfg)
  326. return results_list
  327. def predict_mask(self,
  328. x: Tuple[Tensor],
  329. batch_img_metas: List[dict],
  330. results_list: List[InstanceData],
  331. rescale: bool = False) -> List[InstanceData]:
  332. """Perform forward propagation of the mask head and predict detection
  333. results on the features of the upstream network.
  334. Args:
  335. x (tuple[Tensor]): Feature maps of all scale level.
  336. batch_img_metas (list[dict]): List of image information.
  337. results_list (list[:obj:`InstanceData`]): Detection results of
  338. each image.
  339. rescale (bool): If True, return boxes in original image space.
  340. Defaults to False.
  341. Returns:
  342. list[:obj:`InstanceData`]: Detection results of each image
  343. after the post process.
  344. Each item usually contains following keys.
  345. - scores (Tensor): Classification scores, has a shape
  346. (num_instance, )
  347. - labels (Tensor): Labels of bboxes, has a shape
  348. (num_instances, ).
  349. - bboxes (Tensor): Has a shape (num_instances, 4),
  350. the last dimension 4 arrange as (x1, y1, x2, y2).
  351. - masks (Tensor): Has a shape (num_instances, H, W).
  352. """
  353. bboxes = [res.bboxes for res in results_list]
  354. mask_rois = bbox2roi(bboxes)
  355. if mask_rois.shape[0] == 0:
  356. results_list = empty_instances(
  357. batch_img_metas,
  358. mask_rois.device,
  359. task_type='mask',
  360. instance_results=results_list,
  361. mask_thr_binary=self.test_cfg.mask_thr_binary)
  362. return results_list
  363. num_mask_rois_per_img = [len(res) for res in results_list]
  364. aug_masks = []
  365. for stage in range(self.num_stages):
  366. mask_results = self._mask_forward(stage, x, mask_rois)
  367. mask_preds = mask_results['mask_preds']
  368. # split batch mask prediction back to each image
  369. mask_preds = mask_preds.split(num_mask_rois_per_img, 0)
  370. aug_masks.append([m.sigmoid().detach() for m in mask_preds])
  371. merged_masks = []
  372. for i in range(len(batch_img_metas)):
  373. aug_mask = [mask[i] for mask in aug_masks]
  374. merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
  375. merged_masks.append(merged_mask)
  376. results_list = self.mask_head[-1].predict_by_feat(
  377. mask_preds=merged_masks,
  378. results_list=results_list,
  379. batch_img_metas=batch_img_metas,
  380. rcnn_test_cfg=self.test_cfg,
  381. rescale=rescale,
  382. activate_map=True)
  383. return results_list
  384. def _refine_roi(self, x: Tuple[Tensor], rois: Tensor,
  385. batch_img_metas: List[dict],
  386. num_proposals_per_img: Sequence[int], **kwargs) -> tuple:
  387. """Multi-stage refinement of RoI.
  388. Args:
  389. x (tuple[Tensor]): List of multi-level img features.
  390. rois (Tensor): shape (n, 5), [batch_ind, x1, y1, x2, y2]
  391. batch_img_metas (list[dict]): List of image information.
  392. num_proposals_per_img (sequence[int]): number of proposals
  393. in each image.
  394. Returns:
  395. tuple:
  396. - rois (Tensor): Refined RoI.
  397. - cls_scores (list[Tensor]): Average predicted
  398. cls score per image.
  399. - bbox_preds (list[Tensor]): Bbox branch predictions
  400. for the last stage of per image.
  401. """
  402. # "ms" in variable names means multi-stage
  403. ms_scores = []
  404. for stage in range(self.num_stages):
  405. bbox_results = self._bbox_forward(
  406. stage=stage, x=x, rois=rois, **kwargs)
  407. # split batch bbox prediction back to each image
  408. cls_scores = bbox_results['cls_score']
  409. bbox_preds = bbox_results['bbox_pred']
  410. rois = rois.split(num_proposals_per_img, 0)
  411. cls_scores = cls_scores.split(num_proposals_per_img, 0)
  412. ms_scores.append(cls_scores)
  413. # some detector with_reg is False, bbox_preds will be None
  414. if bbox_preds is not None:
  415. # TODO move this to a sabl_roi_head
  416. # the bbox prediction of some detectors like SABL is not Tensor
  417. if isinstance(bbox_preds, torch.Tensor):
  418. bbox_preds = bbox_preds.split(num_proposals_per_img, 0)
  419. else:
  420. bbox_preds = self.bbox_head[stage].bbox_pred_split(
  421. bbox_preds, num_proposals_per_img)
  422. else:
  423. bbox_preds = (None, ) * len(batch_img_metas)
  424. if stage < self.num_stages - 1:
  425. bbox_head = self.bbox_head[stage]
  426. if bbox_head.custom_activation:
  427. cls_scores = [
  428. bbox_head.loss_cls.get_activation(s)
  429. for s in cls_scores
  430. ]
  431. refine_rois_list = []
  432. for i in range(len(batch_img_metas)):
  433. if rois[i].shape[0] > 0:
  434. bbox_label = cls_scores[i][:, :-1].argmax(dim=1)
  435. # Refactor `bbox_head.regress_by_class` to only accept
  436. # box tensor without img_idx concatenated.
  437. refined_bboxes = bbox_head.regress_by_class(
  438. rois[i][:, 1:], bbox_label, bbox_preds[i],
  439. batch_img_metas[i])
  440. refined_bboxes = get_box_tensor(refined_bboxes)
  441. refined_rois = torch.cat(
  442. [rois[i][:, [0]], refined_bboxes], dim=1)
  443. refine_rois_list.append(refined_rois)
  444. rois = torch.cat(refine_rois_list)
  445. # average scores of each image by stages
  446. cls_scores = [
  447. sum([score[i] for score in ms_scores]) / float(len(ms_scores))
  448. for i in range(len(batch_img_metas))
  449. ]
  450. return rois, cls_scores, bbox_preds
  451. def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  452. batch_data_samples: SampleList) -> tuple:
  453. """Network forward process. Usually includes backbone, neck and head
  454. forward without any post-processing.
  455. Args:
  456. x (List[Tensor]): Multi-level features that may have different
  457. resolutions.
  458. rpn_results_list (list[:obj:`InstanceData`]): List of region
  459. proposals.
  460. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  461. the meta information of each image and corresponding
  462. annotations.
  463. Returns
  464. tuple: A tuple of features from ``bbox_head`` and ``mask_head``
  465. forward.
  466. """
  467. results = ()
  468. batch_img_metas = [
  469. data_samples.metainfo for data_samples in batch_data_samples
  470. ]
  471. proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
  472. num_proposals_per_img = tuple(len(p) for p in proposals)
  473. rois = bbox2roi(proposals)
  474. # bbox head
  475. if self.with_bbox:
  476. rois, cls_scores, bbox_preds = self._refine_roi(
  477. x, rois, batch_img_metas, num_proposals_per_img)
  478. results = results + (cls_scores, bbox_preds)
  479. # mask head
  480. if self.with_mask:
  481. aug_masks = []
  482. rois = torch.cat(rois)
  483. for stage in range(self.num_stages):
  484. mask_results = self._mask_forward(stage, x, rois)
  485. mask_preds = mask_results['mask_preds']
  486. mask_preds = mask_preds.split(num_proposals_per_img, 0)
  487. aug_masks.append([m.sigmoid().detach() for m in mask_preds])
  488. merged_masks = []
  489. for i in range(len(batch_img_metas)):
  490. aug_mask = [mask[i] for mask in aug_masks]
  491. merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
  492. merged_masks.append(merged_mask)
  493. results = results + (merged_masks, )
  494. return results