structures.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import itertools
  3. from abc import ABCMeta, abstractmethod
  4. from typing import Sequence, Type, TypeVar
  5. import cv2
  6. import mmcv
  7. import numpy as np
  8. import pycocotools.mask as maskUtils
  9. import shapely.geometry as geometry
  10. import torch
  11. from mmcv.ops.roi_align import roi_align
  # Generic type variable used in the ``cat`` classmethod annotations so that
  # ``cat`` on a mask class is typed as returning an instance of that class.
  T = TypeVar('T')
  class BaseInstanceMasks(metaclass=ABCMeta):
      """Base class for instance masks.

      Subclasses hold one mask per object instance together with the
      ``height``/``width`` of the image they annotate, and implement the
      geometric transforms declared here so that masks can follow the image
      through data augmentation.
      """

      @abstractmethod
      def rescale(self, scale, interpolation='nearest'):
          """Rescale masks as large as possible while keeping the aspect ratio.

          For details can refer to `mmcv.imrescale`.

          Args:
              scale (tuple[int]): The maximum size (h, w) of rescaled mask.
              interpolation (str): Same as :func:`mmcv.imrescale`.

          Returns:
              BaseInstanceMasks: The rescaled masks.
          """

      @abstractmethod
      def resize(self, out_shape, interpolation='nearest'):
          """Resize masks to the given out_shape.

          Args:
              out_shape: Target (h, w) of resized mask.
              interpolation (str): See :func:`mmcv.imresize`.

          Returns:
              BaseInstanceMasks: The resized masks.
          """

      @abstractmethod
      def flip(self, flip_direction='horizontal'):
          """Flip masks along the given direction.

          Args:
              flip_direction (str): 'horizontal', 'vertical' or 'diagonal'
                  (flipping along both axes).

          Returns:
              BaseInstanceMasks: The flipped masks.
          """

      @abstractmethod
      def pad(self, out_shape, pad_val):
          """Pad masks to the given size of (h, w).

          Args:
              out_shape (tuple[int]): Target (h, w) of padded mask.
              pad_val (int): The padded value.

          Returns:
              BaseInstanceMasks: The padded masks.
          """

      @abstractmethod
      def crop(self, bbox):
          """Crop each mask by the given bbox.

          Args:
              bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ).

          Returns:
              BaseInstanceMasks: The cropped masks.
          """

      @abstractmethod
      def crop_and_resize(self,
                          bboxes,
                          out_shape,
                          inds,
                          device,
                          interpolation='bilinear',
                          binarize=True):
          """Crop and resize masks by the given bboxes.

          This function is mainly used in mask targets computation.
          It firstly align mask to bboxes by assigned_inds, then crop mask by the
          assigned bbox and resize to the size of (mask_h, mask_w)

          Args:
              bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4)
              out_shape (tuple[int]): Target (h, w) of resized mask
              inds (ndarray): Indexes to assign masks to each bbox,
                  shape (N,) and values should be between [0, num_masks - 1].
              device (str): Device of bboxes
              interpolation (str): See `mmcv.imresize`
              binarize (bool): if True fractional values are rounded to 0 or 1
                  after the resize operation. if False and unsupported an error
                  will be raised. Defaults to True.

          Returns:
              BaseInstanceMasks: the cropped and resized masks.
          """

      @abstractmethod
      def expand(self, expanded_h, expanded_w, top, left):
          """see :class:`Expand`."""

      @property
      @abstractmethod
      def areas(self):
          """ndarray: areas of each instance."""

      @abstractmethod
      def to_ndarray(self):
          """Convert masks to the format of ndarray.

          Returns:
              ndarray: Converted masks in the format of ndarray.
          """

      @abstractmethod
      def to_tensor(self, dtype, device):
          """Convert masks to the format of Tensor.

          Args:
              dtype (str): Dtype of converted mask.
              device (torch.device): Device of converted masks.

          Returns:
              Tensor: Converted masks in the format of Tensor.
          """

      @abstractmethod
      def translate(self,
                    out_shape,
                    offset,
                    direction='horizontal',
                    border_value=0,
                    interpolation='bilinear'):
          """Translate the masks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              offset (int | float): The offset for translate.
              direction (str): The translate direction, either "horizontal"
                  or "vertical".
              border_value (int | float): Border value. Default 0.
              interpolation (str): Same as :func:`mmcv.imtranslate`.

          Returns:
              BaseInstanceMasks: Translated masks.
          """

      def shear(self,
                out_shape,
                magnitude,
                direction='horizontal',
                border_value=0,
                interpolation='bilinear'):
          """Shear the masks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              magnitude (int | float): The magnitude used for shear.
              direction (str): The shear direction, either "horizontal"
                  or "vertical".
              border_value (int | tuple[int]): Value used in case of a
                  constant border. Default 0.
              interpolation (str): Same as in :func:`mmcv.imshear`.

          Returns:
              BaseInstanceMasks: Sheared masks.

          Raises:
              NotImplementedError: If the subclass does not override it.
          """
          # Deliberately not an @abstractmethod so that existing subclasses
          # without shear support can still be instantiated; calling it must
          # fail loudly instead of silently returning None.
          raise NotImplementedError(
              f'{type(self).__name__} does not implement shear().')

      @abstractmethod
      def rotate(self, out_shape, angle, center=None, scale=1.0, border_value=0):
          """Rotate the masks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              angle (int | float): Rotation angle in degrees. Positive values
                  mean counter-clockwise rotation.
              center (tuple[float], optional): Center point (w, h) of the
                  rotation in source image. If not specified, the center of
                  the image will be used.
              scale (int | float): Isotropic scale factor.
              border_value (int | float): Border value. Default 0 for masks.

          Returns:
              BaseInstanceMasks: Rotated masks.
          """

      def get_bboxes(self, dst_type='hbb'):
          """Get the certain type boxes from masks.

          Please refer to ``mmdet.structures.bbox.box_type`` for more details of
          the box type.

          Args:
              dst_type: Destination box type.

          Returns:
              :obj:`BaseBoxes`: Certain type boxes.
          """
          # Local import: the bbox package is only needed here.
          from ..bbox import get_box_type
          _, box_type_cls = get_box_type(dst_type)
          return box_type_cls.from_instance_masks(self)

      @classmethod
      @abstractmethod
      def cat(cls: Type[T], masks: Sequence[T]) -> T:
          """Concatenate a sequence of masks into one single mask instance.

          Args:
              masks (Sequence[T]): A sequence of mask instances.

          Returns:
              T: Concatenated mask instance.
          """
  class BitmapMasks(BaseInstanceMasks):
      """This class represents masks in the form of bitmaps.

      Args:
          masks (ndarray | list[ndarray]): ndarray of masks in shape
              (N, H, W), or a list of N (H, W) arrays, where N is the number
              of objects.
          height (int): height of masks
          width (int): width of masks

      Example:
          >>> from mmdet.structures.mask.structures import *  # NOQA
          >>> num_masks, H, W = 3, 32, 32
          >>> rng = np.random.RandomState(0)
          >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.int64)
          >>> self = BitmapMasks(masks, height=H, width=W)

          >>> # demo crop_and_resize
          >>> num_boxes = 5
          >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
          >>> out_shape = (14, 14)
          >>> inds = torch.randint(0, len(self), size=(num_boxes,))
          >>> device = 'cpu'
          >>> interpolation = 'bilinear'
          >>> new = self.crop_and_resize(
          ...     bboxes, out_shape, inds, device, interpolation)
          >>> assert len(new) == num_boxes
          >>> assert (new.height, new.width) == out_shape
      """

      def __init__(self, masks, height, width):
          """Store ``masks`` as an (N, height, width) array.

          An empty ``masks`` becomes a (0, height, width) uint8 array.
          """
          self.height = height
          self.width = width
          if len(masks) == 0:
              self.masks = np.empty((0, self.height, self.width),
                                    dtype=np.uint8)
          else:
              assert isinstance(masks, (list, np.ndarray))
              if isinstance(masks, list):
                  assert isinstance(masks[0], np.ndarray)
                  assert masks[0].ndim == 2  # (H, W)
              else:
                  assert masks.ndim == 3  # (N, H, W)

              self.masks = np.stack(masks).reshape(-1, height, width)
              assert self.masks.shape[1] == self.height
              assert self.masks.shape[2] == self.width

      def __getitem__(self, index):
          """Index the BitmapMask.

          Args:
              index (int | ndarray): Indices in the format of integer or ndarray.

          Returns:
              :obj:`BitmapMasks`: Indexed bitmap masks.
          """
          # reshape restores the leading axis when an integer index drops it
          masks = self.masks[index].reshape(-1, self.height, self.width)
          return BitmapMasks(masks, self.height, self.width)

      def __iter__(self):
          return iter(self.masks)

      def __repr__(self):
          s = self.__class__.__name__ + '('
          s += f'num_masks={len(self.masks)}, '
          s += f'height={self.height}, '
          s += f'width={self.width})'
          return s

      def __len__(self):
          """Number of masks."""
          return len(self.masks)

      def rescale(self, scale, interpolation='nearest'):
          """See :func:`BaseInstanceMasks.rescale`."""
          if len(self.masks) == 0:
              new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
              rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8)
          else:
              rescaled_masks = np.stack([
                  mmcv.imrescale(mask, scale, interpolation=interpolation)
                  for mask in self.masks
              ])
          height, width = rescaled_masks.shape[1:]
          return BitmapMasks(rescaled_masks, height, width)

      def resize(self, out_shape, interpolation='nearest'):
          """See :func:`BaseInstanceMasks.resize`."""
          if len(self.masks) == 0:
              resized_masks = np.empty((0, *out_shape), dtype=np.uint8)
          else:
              # mmcv.imresize takes the size as (w, h), hence the reversal
              resized_masks = np.stack([
                  mmcv.imresize(
                      mask, out_shape[::-1], interpolation=interpolation)
                  for mask in self.masks
              ])
          return BitmapMasks(resized_masks, *out_shape)

      def flip(self, flip_direction='horizontal'):
          """See :func:`BaseInstanceMasks.flip`."""
          assert flip_direction in ('horizontal', 'vertical', 'diagonal')

          if len(self.masks) == 0:
              flipped_masks = self.masks
          else:
              flipped_masks = np.stack([
                  mmcv.imflip(mask, direction=flip_direction)
                  for mask in self.masks
              ])
          return BitmapMasks(flipped_masks, self.height, self.width)

      def pad(self, out_shape, pad_val=0):
          """See :func:`BaseInstanceMasks.pad`."""
          if len(self.masks) == 0:
              padded_masks = np.empty((0, *out_shape), dtype=np.uint8)
          else:
              padded_masks = np.stack([
                  mmcv.impad(mask, shape=out_shape, pad_val=pad_val)
                  for mask in self.masks
              ])
          return BitmapMasks(padded_masks, *out_shape)

      def crop(self, bbox):
          """See :func:`BaseInstanceMasks.crop`.

          The box is clipped to the mask extent and its width/height are
          forced to be at least 1. Its coordinates are used directly as slice
          bounds, so they are expected to be integers. Every instance is kept:
          a crop that falls entirely past the right/bottom border yields
          all-zero masks instead of dropping the instances.
          """
          assert isinstance(bbox, np.ndarray)
          assert bbox.ndim == 1

          # clip the boundary
          bbox = bbox.copy()
          bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
          bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
          x1, y1, x2, y2 = bbox
          w = np.maximum(x2 - x1, 1)
          h = np.maximum(y2 - y1, 1)

          if len(self.masks) == 0:
              cropped_masks = np.empty((0, h, w), dtype=np.uint8)
          else:
              cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w]
              if cropped_masks.shape[1:] != (h, w):
                  # Only happens when x1 == width or y1 == height: the slice
                  # is empty, and reshaping it to (-1, h, w) would silently
                  # yield zero masks. Zero-pad so each instance keeps a mask.
                  padded = np.zeros((len(self.masks), h, w),
                                    dtype=self.masks.dtype)
                  ch, cw = cropped_masks.shape[1:]
                  padded[:, :ch, :cw] = cropped_masks
                  cropped_masks = padded
          return BitmapMasks(cropped_masks, h, w)

      def crop_and_resize(self,
                          bboxes,
                          out_shape,
                          inds,
                          device='cpu',
                          interpolation='bilinear',
                          binarize=True):
          """See :func:`BaseInstanceMasks.crop_and_resize`."""
          if len(self.masks) == 0:
              empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
              return BitmapMasks(empty_masks, *out_shape)

          # convert bboxes to tensor
          if isinstance(bboxes, np.ndarray):
              bboxes = torch.from_numpy(bboxes).to(device=device)
          if isinstance(inds, np.ndarray):
              inds = torch.from_numpy(inds).to(device=device)

          num_bbox = bboxes.shape[0]
          # roi i reads from batch element i (the i-th gathered mask)
          fake_inds = torch.arange(
              num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
          rois = torch.cat([fake_inds, bboxes], dim=1)  # Nx5
          rois = rois.to(device=device)
          if num_bbox > 0:
              gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
                  0, inds).to(dtype=rois.dtype)
              # positional roi_align args: spatial_scale=1.0,
              # sampling_ratio=0, pool_mode='avg', aligned=True
              targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
                                  1.0, 0, 'avg', True).squeeze(1)
              if binarize:
                  resized_masks = (targets >= 0.5).cpu().numpy()
              else:
                  resized_masks = targets.cpu().numpy()
          else:
              resized_masks = []
          return BitmapMasks(resized_masks, *out_shape)

      def expand(self, expanded_h, expanded_w, top, left):
          """See :func:`BaseInstanceMasks.expand`."""
          if len(self.masks) == 0:
              expanded_mask = np.empty((0, expanded_h, expanded_w),
                                       dtype=np.uint8)
          else:
              expanded_mask = np.zeros((len(self), expanded_h, expanded_w),
                                       dtype=np.uint8)
              expanded_mask[:, top:top + self.height,
                            left:left + self.width] = self.masks
          return BitmapMasks(expanded_mask, expanded_h, expanded_w)

      def translate(self,
                    out_shape,
                    offset,
                    direction='horizontal',
                    border_value=0,
                    interpolation='bilinear'):
          """Translate the BitmapMasks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              offset (int | float): The offset for translate.
              direction (str): The translate direction, either "horizontal"
                  or "vertical".
              border_value (int | float): Border value. Default 0 for masks.
              interpolation (str): Same as :func:`mmcv.imtranslate`.

          Returns:
              BitmapMasks: Translated BitmapMasks.

          Example:
              >>> from mmdet.structures.mask.structures import BitmapMasks
              >>> self = BitmapMasks.random(dtype=np.uint8)
              >>> out_shape = (32, 32)
              >>> offset = 4
              >>> direction = 'horizontal'
              >>> border_value = 0
              >>> interpolation = 'bilinear'
              >>> # Note, There seem to be issues when:
              >>> # * the mask dtype is not supported by cv2.AffineWarp
              >>> new = self.translate(out_shape, offset, direction,
              ...                      border_value, interpolation)
              >>> assert len(new) == len(self)
              >>> assert (new.height, new.width) == out_shape
          """
          if len(self.masks) == 0:
              translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
          else:
              masks = self.masks
              if masks.shape[-2:] != out_shape:
                  # crop / zero-pad the masks to out_shape before translating
                  empty_masks = np.zeros((masks.shape[0], *out_shape),
                                         dtype=masks.dtype)
                  min_h = min(out_shape[0], masks.shape[1])
                  min_w = min(out_shape[1], masks.shape[2])
                  empty_masks[:, :min_h, :min_w] = masks[:, :min_h, :min_w]
                  masks = empty_masks
              # mmcv works on (H, W, C) images: treat masks as channels
              translated_masks = mmcv.imtranslate(
                  masks.transpose((1, 2, 0)),
                  offset,
                  direction,
                  border_value=border_value,
                  interpolation=interpolation)
              if translated_masks.ndim == 2:
                  translated_masks = translated_masks[:, :, None]
              translated_masks = translated_masks.transpose(
                  (2, 0, 1)).astype(self.masks.dtype)
          return BitmapMasks(translated_masks, *out_shape)

      def shear(self,
                out_shape,
                magnitude,
                direction='horizontal',
                border_value=0,
                interpolation='bilinear'):
          """Shear the BitmapMasks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              magnitude (int | float): The magnitude used for shear.
              direction (str): The shear direction, either "horizontal"
                  or "vertical".
              border_value (int | tuple[int]): Value used in case of a
                  constant border.
              interpolation (str): Same as in :func:`mmcv.imshear`.

          Returns:
              BitmapMasks: The sheared masks.
          """
          if len(self.masks) == 0:
              sheared_masks = np.empty((0, *out_shape), dtype=np.uint8)
          else:
              # mmcv works on (H, W, C) images: treat masks as channels
              sheared_masks = mmcv.imshear(
                  self.masks.transpose((1, 2, 0)),
                  magnitude,
                  direction,
                  border_value=border_value,
                  interpolation=interpolation)
              if sheared_masks.ndim == 2:
                  sheared_masks = sheared_masks[:, :, None]
              sheared_masks = sheared_masks.transpose(
                  (2, 0, 1)).astype(self.masks.dtype)
          return BitmapMasks(sheared_masks, *out_shape)

      def rotate(self,
                 out_shape,
                 angle,
                 center=None,
                 scale=1.0,
                 border_value=0,
                 interpolation='bilinear'):
          """Rotate the BitmapMasks.

          Args:
              out_shape (tuple[int]): Shape for output mask, format (h, w).
              angle (int | float): Rotation angle in degrees. Positive values
                  mean counter-clockwise rotation.
              center (tuple[float], optional): Center point (w, h) of the
                  rotation in source image. If not specified, the center of
                  the image will be used.
              scale (int | float): Isotropic scale factor.
              border_value (int | float): Border value. Default 0 for masks.
              interpolation (str): Same as in :func:`mmcv.imrotate`.

          Returns:
              BitmapMasks: Rotated BitmapMasks.
          """
          if len(self.masks) == 0:
              rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype)
          else:
              rotated_masks = mmcv.imrotate(
                  self.masks.transpose((1, 2, 0)),
                  angle,
                  center=center,
                  scale=scale,
                  border_value=border_value,
                  interpolation=interpolation)
              if rotated_masks.ndim == 2:
                  # case when only one mask, (h, w)
                  rotated_masks = rotated_masks[:, :, None]  # (h, w, 1)
              rotated_masks = rotated_masks.transpose(
                  (2, 0, 1)).astype(self.masks.dtype)
          return BitmapMasks(rotated_masks, *out_shape)

      @property
      def areas(self):
          """See :py:attr:`BaseInstanceMasks.areas`."""
          return self.masks.sum((1, 2))

      def to_ndarray(self):
          """See :func:`BaseInstanceMasks.to_ndarray`."""
          return self.masks

      def to_tensor(self, dtype, device):
          """See :func:`BaseInstanceMasks.to_tensor`."""
          return torch.tensor(self.masks, dtype=dtype, device=device)

      @classmethod
      def random(cls,
                 num_masks=3,
                 height=32,
                 width=32,
                 dtype=np.uint8,
                 rng=None):
          """Generate random bitmap masks for demo / testing purposes.

          Example:
              >>> from mmdet.structures.mask.structures import BitmapMasks
              >>> self = BitmapMasks.random()
              >>> print('self = {}'.format(self))
              self = BitmapMasks(num_masks=3, height=32, width=32)
          """
          from mmdet.utils.util_random import ensure_rng
          rng = ensure_rng(rng)
          masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype)
          self = cls(masks, height=height, width=width)
          return self

      @classmethod
      def cat(cls: Type[T], masks: Sequence[T]) -> T:
          """Concatenate a sequence of masks into one single mask instance.

          Args:
              masks (Sequence[BitmapMasks]): A sequence of mask instances.

          Returns:
              BitmapMasks: Concatenated mask instance.

          Raises:
              ValueError: If ``masks`` is empty.
          """
          assert isinstance(masks, Sequence)
          if len(masks) == 0:
              raise ValueError('masks should not be an empty list.')
          assert all(isinstance(m, cls) for m in masks)

          mask_array = np.concatenate([m.masks for m in masks], axis=0)
          return cls(mask_array, *mask_array.shape[1:])
  507. class PolygonMasks(BaseInstanceMasks):
  508. """This class represents masks in the form of polygons.
  509. Polygons is a list of three levels. The first level of the list
  510. corresponds to objects, the second level to the polys that compose the
  511. object, the third level to the poly coordinates
  512. Args:
  513. masks (list[list[ndarray]]): The first level of the list
  514. corresponds to objects, the second level to the polys that
  515. compose the object, the third level to the poly coordinates
  516. height (int): height of masks
  517. width (int): width of masks
  518. Example:
  519. >>> from mmdet.data_elements.mask.structures import * # NOQA
  520. >>> masks = [
  521. >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ]
  522. >>> ]
  523. >>> height, width = 16, 16
  524. >>> self = PolygonMasks(masks, height, width)
  525. >>> # demo translate
  526. >>> new = self.translate((16, 16), 4., direction='horizontal')
  527. >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2])
  528. >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4)
  529. >>> # demo crop_and_resize
  530. >>> num_boxes = 3
  531. >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
  532. >>> out_shape = (16, 16)
  533. >>> inds = torch.randint(0, len(self), size=(num_boxes,))
  534. >>> device = 'cpu'
  535. >>> interpolation = 'bilinear'
  536. >>> new = self.crop_and_resize(
  537. ... bboxes, out_shape, inds, device, interpolation)
  538. >>> assert len(new) == num_boxes
  539. >>> assert new.height, new.width == out_shape
  540. """
  def __init__(self, masks, height, width):
      """Hold polygon masks along with the size of the image they annotate.

      Args:
          masks (list[list[ndarray]]): One list per object; each entry is a
              flattened coordinate array ``[x0, y0, x1, y1, ...]``.
          height (int): Height of the image the polygons refer to.
          width (int): Width of the image the polygons refer to.
      """
      assert isinstance(masks, list)
      if masks:
          # only the first object is type-checked
          first_obj = masks[0]
          assert isinstance(first_obj, list)
          assert isinstance(first_obj[0], np.ndarray)
      self.height, self.width, self.masks = height, width, masks
  def __getitem__(self, index):
      """Index the polygon masks.

      Args:
          index (int | slice | list | ndarray): The indices. A boolean
              ndarray is interpreted as a selection mask.

      Returns:
          :obj:`PolygonMasks`: The indexed polygon masks.

      Raises:
          ValueError: If ``index`` cannot be used to index the mask list.
      """
      if isinstance(index, np.ndarray):
          if index.dtype == bool:
              index = np.where(index)[0].tolist()
          else:
              index = index.tolist()
      if isinstance(index, list):
          masks = [self.masks[i] for i in index]
      else:
          try:
              masks = self.masks[index]
          except Exception as err:
              # chain the original error so the root cause stays visible
              raise ValueError(
                  f'Unsupported input of type {type(index)} for indexing!'
              ) from err
          # A scalar index returns one object's polygon list; wrap it so the
          # result is still three levels deep.
          # NOTE(review): an object with an empty polygon list is not
          # detected by this check and yields an empty PolygonMasks.
          if len(masks) and isinstance(masks[0], np.ndarray):
              masks = [masks]
      return PolygonMasks(masks, self.height, self.width)
  def __iter__(self):
      """Iterate over the per-object polygon lists."""
      polygons = self.masks
      return iter(polygons)
  def __repr__(self):
      """Summarize the number of objects and the canvas size."""
      return (f'{self.__class__.__name__}('
              f'num_masks={len(self.masks)}, '
              f'height={self.height}, '
              f'width={self.width})')
  def __len__(self):
      """Return how many objects (masks) are stored."""
      num_objects = len(self.masks)
      return num_objects
  def rescale(self, scale, interpolation=None):
      """See :func:`BaseInstanceMasks.rescale`.

      The target size comes from :func:`mmcv.rescale_size`; the polygons are
      then scaled with :meth:`resize`. ``interpolation`` is ignored.
      """
      new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
      if not self.masks:
          return PolygonMasks([], new_h, new_w)
      return self.resize((new_h, new_w))
  def resize(self, out_shape, interpolation=None):
      """See :func:`BaseInstanceMasks.resize`.

      Scales x coordinates by ``out_w / width`` and y coordinates by
      ``out_h / height``. ``interpolation`` is ignored.
      """
      if not self.masks:
          return PolygonMasks([], *out_shape)

      h_scale = out_shape[0] / self.height
      w_scale = out_shape[1] / self.width

      def _scale(poly):
          # in-place assignment keeps the polygon's original dtype
          scaled = poly.copy()
          scaled[0::2] = scaled[0::2] * w_scale
          scaled[1::2] = scaled[1::2] * h_scale
          return scaled

      resized = [[_scale(poly) for poly in poly_per_obj]
                 for poly_per_obj in self.masks]
      return PolygonMasks(resized, *out_shape)
  def flip(self, flip_direction='horizontal'):
      """See :func:`BaseInstanceMasks.flip`.

      Mirrors x as ``width - x`` and/or y as ``height - y``; 'diagonal'
      applies both.
      """
      assert flip_direction in ('horizontal', 'vertical', 'diagonal')
      if not self.masks:
          return PolygonMasks([], self.height, self.width)

      flip_x = flip_direction in ('horizontal', 'diagonal')
      flip_y = flip_direction in ('vertical', 'diagonal')
      flipped = []
      for poly_per_obj in self.masks:
          new_obj = []
          for poly in poly_per_obj:
              poly = poly.copy()
              if flip_x:
                  poly[0::2] = self.width - poly[0::2]
              if flip_y:
                  poly[1::2] = self.height - poly[1::2]
              new_obj.append(poly)
          flipped.append(new_obj)
      return PolygonMasks(flipped, self.height, self.width)
  def crop(self, bbox):
      """See :func:`BaseInstanceMasks.crop`.

      Each polygon is intersected with the clipped box using shapely and
      shifted so that (x1, y1) becomes the origin. A polygon may split into
      several pieces; lines/points produced by the intersection are
      dropped. An object with no surviving polygon gets a degenerate dummy
      polygon so that masks stay aligned with their boxes.

      Args:
          bbox (ndarray): Box [x1, y1, x2, y2], shape (4, ).

      Returns:
          PolygonMasks: The cropped masks with size (h, w) of the box.
      """
      assert isinstance(bbox, np.ndarray)
      assert bbox.ndim == 1

      # clip the boundary
      bbox = bbox.copy()
      bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
      bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
      x1, y1, x2, y2 = bbox
      w = np.maximum(x2 - x1, 1)
      h = np.maximum(y2 - y1, 1)

      if len(self.masks) == 0:
          return PolygonMasks([], h, w)

      # reference: https://github.com/facebookresearch/fvcore/blob/main/fvcore/transforms/transform.py # noqa
      crop_box = geometry.box(x1, y1, x2, y2).buffer(0.0)
      cropped_masks = []
      # suppress shapely warnings util it incorporates GEOS>=3.11.2
      # reference: https://github.com/shapely/shapely/issues/1345
      # np.errstate restores the previous settings even if shapely raises,
      # unlike a manual np.seterr()/np.seterr(**old) pair.
      with np.errstate(invalid='ignore'):
          for poly_per_obj in self.masks:
              cropped_poly_per_obj = []
              for p in poly_per_obj:
                  p = geometry.Polygon(p.reshape(-1, 2)).buffer(0.0)
                  # polygon must be valid to perform intersection.
                  if not p.is_valid:
                      continue
                  cropped = p.intersection(crop_box)
                  if cropped.is_empty:
                      continue
                  if isinstance(cropped,
                                geometry.collection.BaseMultipartGeometry):
                      cropped = cropped.geoms
                  else:
                      cropped = [cropped]
                  # one polygon may be cropped to multiple ones
                  for poly in cropped:
                      # ignore lines or points
                      if not isinstance(
                              poly, geometry.Polygon) or not poly.is_valid:
                          continue
                      coords = np.asarray(poly.exterior.coords)
                      # remove an extra identical vertex at the end
                      coords = coords[:-1]
                      coords[:, 0] -= x1
                      coords[:, 1] -= y1
                      cropped_poly_per_obj.append(coords.reshape(-1))
              # a dummy polygon to avoid misalignment between masks and boxes
              if len(cropped_poly_per_obj) == 0:
                  cropped_poly_per_obj = [np.array([0, 0, 0, 0, 0, 0])]
              cropped_masks.append(cropped_poly_per_obj)
      return PolygonMasks(cropped_masks, h, w)
  def pad(self, out_shape, pad_val=0):
      """Return the same polygons on a canvas of size ``out_shape``.

      Polygon coordinates are absolute, so only the stored height/width
      change; ``pad_val`` is ignored. This assumes padding is added at the
      bottom/right so existing coordinates do not shift.
      """
      polygons = self.masks  # shared with self, not copied
      return PolygonMasks(polygons, *out_shape)
  def expand(self, *args, **kwargs):
      """Expanding polygon masks is not supported yet.

      TODO: Add expand for polygon.

      Raises:
          NotImplementedError: Always.
      """
      raise NotImplementedError()
  def crop_and_resize(self,
                      bboxes,
                      out_shape,
                      inds,
                      device='cpu',
                      interpolation='bilinear',
                      binarize=True):
      """See :func:`BaseInstanceMasks.crop_and_resize`.

      For box i, the polygons of object ``inds[i]`` are shifted so the box
      corner is the origin, then scaled to ``out_shape``. No clipping is
      done here; pycocotools clips when the polygons are rasterized.
      ``device`` and ``interpolation`` are ignored.

      Raises:
          ValueError: If ``binarize`` is False (polygons are always binary).
      """
      out_h, out_w = out_shape
      if len(self.masks) == 0:
          return PolygonMasks([], out_h, out_w)

      if not binarize:
          raise ValueError('Polygons are always binary, '
                           'setting binarize=False is unsupported')

      def _warp(poly, box, sx, sy):
          # crop (shift by the box corner), then resize
          warped = poly.copy()
          warped[0::2] = warped[0::2] - box[0]
          warped[1::2] = warped[1::2] - box[1]
          warped[0::2] = warped[0::2] * sx
          warped[1::2] = warped[1::2] * sy
          return warped

      resized_masks = []
      for i in range(len(bboxes)):
          polys = self.masks[inds[i]]
          box = bboxes[i, :]
          bx1, by1, bx2, by2 = box
          box_w = np.maximum(bx2 - bx1, 1)
          box_h = np.maximum(by2 - by1, 1)
          sy = out_h / max(box_h, 0.1)  # avoid too large scale
          sx = out_w / max(box_w, 0.1)
          resized_masks.append([_warp(p, box, sx, sy) for p in polys])
      return PolygonMasks(resized_masks, *out_shape)
  def translate(self,
                out_shape,
                offset,
                direction='horizontal',
                border_value=None,
                interpolation=None):
      """Translate the PolygonMasks.

      Shifts x ('horizontal') or y ('vertical') coordinates by ``offset``
      and clips them to ``[0, out_w]`` / ``[0, out_h]``. Any other direction
      leaves the coordinates unchanged. ``interpolation`` is ignored.

      Example:
          >>> self = PolygonMasks.random(dtype=np.int64)
          >>> out_shape = (self.height, self.width)
          >>> new = self.translate(out_shape, 4., direction='horizontal')
          >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2])
          >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4)  # noqa: E501
      """
      assert border_value is None or border_value == 0, \
          'Here border_value is not '\
          f'used, and defaultly should be None or 0. got {border_value}.'
      if not self.masks:
          return PolygonMasks([], *out_shape)

      def _shift(poly):
          moved = poly.copy()
          if direction == 'horizontal':
              moved[0::2] = np.clip(moved[0::2] + offset, 0, out_shape[1])
          elif direction == 'vertical':
              moved[1::2] = np.clip(moved[1::2] + offset, 0, out_shape[0])
          return moved

      shifted = [[_shift(poly) for poly in poly_per_obj]
                 for poly_per_obj in self.masks]
      return PolygonMasks(shifted, *out_shape)
  def shear(self,
            out_shape,
            magnitude,
            direction='horizontal',
            border_value=0,
            interpolation='bilinear'):
      """See :func:`BaseInstanceMasks.shear`.

      Each vertex (x, y) is multiplied by a 2x2 shear matrix and the result
      is clipped to ``out_shape``. ``border_value`` and ``interpolation``
      are unused for polygons.

      Args:
          out_shape (tuple[int]): Target (height, width).
          magnitude (float): Shear factor.
          direction (str): ``'horizontal'`` (x' = x + magnitude * y) or
              ``'vertical'`` (y' = magnitude * x + y).

      Returns:
          PolygonMasks: The sheared masks.

      Raises:
          ValueError: If ``direction`` is neither ``'horizontal'`` nor
              ``'vertical'`` (previously this surfaced as an
              ``UnboundLocalError`` on ``shear_matrix``).
      """
      if direction == 'horizontal':
          shear_matrix = np.array([[1, magnitude], [0, 1]], dtype=np.float32)
      elif direction == 'vertical':
          shear_matrix = np.array([[1, 0], [magnitude, 1]], dtype=np.float32)
      else:
          raise ValueError("direction must be 'horizontal' or 'vertical', "
                           f'got {direction!r}')

      sheared_masks = []
      for poly_per_obj in self.masks:
          sheared_poly = []
          for p in poly_per_obj:
              coords = np.stack([p[0::2], p[1::2]], axis=0)  # [2, n]
              new_coords = np.matmul(shear_matrix, coords)  # [2, n]
              new_coords[0, :] = np.clip(new_coords[0, :], 0, out_shape[1])
              new_coords[1, :] = np.clip(new_coords[1, :], 0, out_shape[0])
              # Back to the interleaved [x0, y0, x1, y1, ...] layout.
              sheared_poly.append(new_coords.transpose((1, 0)).reshape(-1))
          sheared_masks.append(sheared_poly)
      return PolygonMasks(sheared_masks, *out_shape)
  def rotate(self,
             out_shape,
             angle,
             center=None,
             scale=1.0,
             border_value=0,
             interpolation='bilinear'):
      """See :func:`BaseInstanceMasks.rotate`.

      Each vertex is mapped through the 2x3 affine matrix returned by
      ``cv2.getRotationMatrix2D`` and clipped to ``out_shape``.
      ``border_value`` and ``interpolation`` are unused for polygons.

      Args:
          out_shape (tuple[int]): Target (height, width).
          angle (float): Rotation angle in degrees. It is negated before
              being passed to OpenCV (where positive means
              counter-clockwise), so positive values rotate clockwise.
          center (tuple[float], optional): Rotation center (x, y). Defaults
              to the center of the current canvas,
              ``((width - 1) * 0.5, (height - 1) * 0.5)``, the same
              convention as ``mmcv.imrotate``. Previously ``None`` was passed
              straight to OpenCV, which rejects it.
          scale (float): Isotropic scale factor.

      Returns:
          PolygonMasks: The rotated masks.
      """
      if center is None:
          center = ((self.width - 1) * 0.5, (self.height - 1) * 0.5)
      if len(self.masks) == 0:
          return PolygonMasks([], *out_shape)

      rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)  # [2, 3]
      rotated_masks = []
      for poly_per_obj in self.masks:
          rotated_poly = []
          for p in poly_per_obj:
              coords = np.stack([p[0::2], p[1::2]], axis=1)  # [n, 2]
              # Pad with 1 to get homogeneous coordinates [x, y, 1] so the
              # translation column of the affine matrix is applied.
              coords = np.concatenate(
                  (coords, np.ones((coords.shape[0], 1), coords.dtype)),
                  axis=1)  # [n, 3]
              rotated_coords = coords @ rotate_matrix.T  # [n, 2]
              rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0,
                                             out_shape[1])
              rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0,
                                             out_shape[0])
              rotated_poly.append(rotated_coords.reshape(-1))
          rotated_masks.append(rotated_poly)
      return PolygonMasks(rotated_masks, *out_shape)
  def to_bitmap(self):
      """Rasterize the polygons and wrap the result in ``BitmapMasks``.

      The returned masks keep this container's height and width.
      """
      return BitmapMasks(self.to_ndarray(), self.height, self.width)
  @property
  def areas(self):
      """Compute areas of masks.

      This func is modified from `detectron2
      <https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387>`_.
      The function only works with Polygons using the shoelace formula.

      The area of an object is the plain sum of the absolute areas of its
      polygons, so overlapping parts (or polygons meant as holes) are all
      added rather than subtracted.

      Return:
          ndarray: areas of each instance
      """  # noqa: E501
      per_object = [
          sum(self._polygon_area(poly[0::2], poly[1::2]) for poly in polys)
          for polys in self.masks
      ]
      return np.asarray(per_object)
  848. def _polygon_area(self, x, y):
  849. """Compute the area of a component of a polygon.
  850. Using the shoelace formula:
  851. https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
  852. Args:
  853. x (ndarray): x coordinates of the component
  854. y (ndarray): y coordinates of the component
  855. Return:
  856. float: the are of the component
  857. """ # noqa: 501
  858. return 0.5 * np.abs(
  859. np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
  def to_ndarray(self):
      """Rasterize every object's polygons into one stacked array.

      Returns:
          ndarray: Array of shape (N, height, width). For a non-empty
          container the dtype is whatever ``polygon_to_bitmap`` returns;
          an empty container yields a uint8 array of shape
          (0, height, width).
      """
      if len(self.masks) > 0:
          rasters = [
              polygon_to_bitmap(polys, self.height, self.width)
              for polys in self.masks
          ]
          return np.stack(rasters)
      return np.empty((0, self.height, self.width), dtype=np.uint8)
  def to_tensor(self, dtype, device):
      """See :func:`BaseInstanceMasks.to_tensor`.

      The polygons are rasterized with ``to_ndarray`` and copied into a new
      tensor of the requested ``dtype`` on ``device``. An empty container
      yields a tensor of shape (0, height, width).
      """
      if len(self.masks) > 0:
          return torch.tensor(self.to_ndarray(), dtype=dtype, device=device)
      return torch.empty((0, self.height, self.width),
                         dtype=dtype,
                         device=device)
  @classmethod
  def random(cls,
             num_masks=3,
             height=32,
             width=32,
             n_verts=5,
             dtype=np.float32,
             rng=None):
      """Generate random polygon masks for demo / testing purposes.

      Adapted from [1]_

      Each mask consists of a single polygon with ``n_verts`` vertices
      placed inside a ``width`` x ``height`` canvas.

      Args:
          num_masks (int): Number of masks (objects) to generate.
          height (int): Canvas height.
          width (int): Canvas width.
          n_verts (int): Number of vertices of each polygon.
          dtype (type): dtype the vertex coordinates are cast to.
          rng (None | int | RandomState): Source of randomness, normalized
              by ``ensure_rng``.

      Returns:
          PolygonMasks: ``num_masks`` masks of size (height, width).

      References:
          .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379  # noqa: E501

      Example:
          >>> from mmdet.data_elements.mask.structures import PolygonMasks
          >>> self = PolygonMasks.random()
          >>> print('self = {}'.format(self))
      """
      from mmdet.utils.util_random import ensure_rng
      # NOTE(review): ensure_rng presumably turns None / an int seed into a
      # numpy RandomState; the code below relies on rng.uniform, rng.rand
      # and on passing rng as a scipy ``random_state``.
      rng = ensure_rng(rng)

      def _gen_polygon(n, irregularity, spikeyness):
          """Sample one random polygon inside the unit square [0, 1]^2.

          Points are sampled on a circle around the origin. Noise is added
          by varying the angular spacing between consecutive points and the
          radial distance of each point from the center. The points are then
          rescaled into [0, 1], shrunk by a random factor in [0.2, 1) and
          shifted randomly while staying inside the unit square.

          Based on original code by Mike Ounsworth

          Args:
              n (int): number of vertices
              irregularity (float): clipped to [0, 1] and mapped to
                  [0, 2pi/n]; half-width of the uniform range each angular
                  step is drawn from.
              spikeyness (float): clipped to [1e-9, 1] and used as the
                  standard deviation of the truncated normal distribution
                  the radii are drawn from.

          Returns:
              ndarray: (n, 2) array of (x, y) vertices ordered by increasing
              sampling angle.
          """
          from scipy.stats import truncnorm

          # Generate around the unit circle
          cx, cy = (0.0, 0.0)
          radius = 1

          tau = np.pi * 2

          irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n
          spikeyness = np.clip(spikeyness, 1e-9, 1)

          # generate n angle steps
          lower = (tau / n) - irregularity
          upper = (tau / n) + irregularity
          angle_steps = rng.uniform(lower, upper, n)

          # Normalize the steps so they sum to one full turn (2 * pi), then
          # rotate the whole set by a random starting angle.
          k = angle_steps.sum() / (2 * np.pi)
          angles = (angle_steps / k).cumsum() + rng.uniform(0, tau)

          # Convert high and low values to be wrt the standard normal range
          # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
          # Radii follow a normal(mean=radius, std=spikeyness) truncated to
          # [0, 2 * radius].
          low = 0
          high = 2 * radius
          mean = radius
          std = spikeyness
          a = (low - mean) / std
          b = (high - mean) / std
          tnorm = truncnorm(a=a, b=b, loc=mean, scale=std)

          # now generate the points
          radii = tnorm.rvs(n, random_state=rng)
          x_pts = cx + radii * np.cos(angles)
          y_pts = cy + radii * np.sin(angles)

          points = np.hstack([x_pts[:, None], y_pts[:, None]])

          # Scale to 0-1 space (each axis independently)
          points = points - points.min(axis=0)
          points = points / points.max(axis=0)

          # Randomly place within 0-1 space: shrink by a factor in [0.2, 1)
          # and pick an offset that keeps every point inside [0, 1].
          points = points * (rng.rand() * .8 + .2)
          min_pt = points.min(axis=0)
          max_pt = points.max(axis=0)

          high = (1 - max_pt)
          low = (0 - min_pt)
          offset = (rng.rand(2) * (high - low)) + low
          points = points + offset
          return points

      def _order_vertices(verts):
          """Sort vertices by their angle around the vertex centroid.

          References:
              https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise
          """
          # mlat / mlng are the mean x and mean y of the vertices.
          mlat = verts.T[0].sum() / len(verts)
          mlng = verts.T[1].sum() / len(verts)

          tau = np.pi * 2
          # Angle of each vertex around the centroid, wrapped into [0, tau).
          angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) +
                   tau) % tau
          sortx = angle.argsort()
          verts = verts.take(sortx, axis=0)
          return verts

      # Generate a random exterior for each requested mask
      masks = []
      for _ in range(num_masks):
          exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9))
          # Scale unit-square points onto the (width, height) canvas.
          exterior = (exterior * [(width, height)]).astype(dtype)
          # Flatten to the interleaved [x0, y0, x1, y1, ...] layout; each
          # mask holds exactly one polygon.
          masks.append([exterior.ravel()])

      self = cls(masks, height, width)
      return self
  @classmethod
  def cat(cls: Type[T], masks: Sequence[T]) -> T:
      """Concatenate a sequence of masks into one single mask instance.

      The objects of all inputs are joined in order. The height and width
      of the result are taken from the first element of ``masks``.

      Args:
          masks (Sequence[PolygonMasks]): A sequence of mask instances.

      Returns:
          PolygonMasks: Concatenated mask instance.

      Raises:
          ValueError: If ``masks`` is empty.
      """
      assert isinstance(masks, Sequence)
      if len(masks) == 0:
          raise ValueError('masks should not be an empty list.')
      assert all(isinstance(m, cls) for m in masks)
      first = masks[0]
      merged = [poly_per_obj for item in masks for poly_per_obj in item.masks]
      return cls(merged, first.height, first.width)
  def polygon_to_bitmap(polygons, height, width):
      """Rasterize the polygons of one object into a single binary mask.

      Every polygon is encoded to RLE with pycocotools, the RLEs are merged
      into one, and the merged RLE is decoded to a bitmap.

      Args:
          polygons (list[ndarray]): polygons of one object, each a flat
              array of interleaved x, y coordinates.
          height (int): mask height
          width (int): mask width

      Return:
          ndarray: boolean bitmap of the merged polygons.
      """
      encoded = maskUtils.frPyObjects(polygons, height, width)
      merged = maskUtils.merge(encoded)
      return maskUtils.decode(merged).astype(bool)
  def bitmap_to_polygon(bitmap):
      """Convert a mask from bitmap form to polygon contours.

      Args:
          bitmap (ndarray): mask in bitmap representation.

      Return:
          tuple: ``(contours, with_hole)`` where ``contours`` is a
          list[ndarray] of (k, 2) point arrays (one per contour, outer
          boundaries and hole boundaries alike) and ``with_hole`` tells
          whether any contour is a hole. When OpenCV reports no contours,
          ``([], False)`` is returned.
      """
      bitmap = np.ascontiguousarray(bitmap).astype(np.uint8)
      # cv2.RETR_CCOMP organizes contours into a two-level hierarchy: outer
      # boundaries of the components at the top level and boundaries of
      # their holes at the second level. A contour nested inside a hole is
      # again placed at the top level.
      # cv2.CHAIN_APPROX_NONE keeps every contour point (no compression).
      # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
      # (contours, hierarchy); taking the last two items supports both.
      *_, contours, hierarchy = cv2.findContours(bitmap, cv2.RETR_CCOMP,
                                                 cv2.CHAIN_APPROX_NONE)
      if hierarchy is None:
          return [], False
      # Each hierarchy row is [next, previous, first_child, parent], with -1
      # meaning "none". Under RETR_CCOMP a contour that has a parent is the
      # boundary of a hole.
      parent_idx = hierarchy.reshape(-1, 4)[:, 3]
      with_hole = (parent_idx >= 0).any()
      return [c.reshape(-1, 2) for c in contours], with_hole