# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import math
from typing import List, Optional, Sequence, Tuple, Union

import cv2
import mmcv
import numpy as np
from mmcv.image.geometric import _scale_size
from mmcv.transforms import BaseTransform
from mmcv.transforms import Pad as MMCV_Pad
from mmcv.transforms import RandomFlip as MMCV_RandomFlip
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
from mmengine.dataset import BaseDataset
from mmengine.utils import is_str
from numpy import random

from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type
from mmdet.structures.mask import BitmapMasks, PolygonMasks
from mmdet.utils import log_img_scale

try:
    from imagecorruptions import corrupt
except ImportError:
    corrupt = None

try:
    import albumentations
    from albumentations import Compose
except ImportError:
    albumentations = None
    Compose = None

Number = Union[int, float]
@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
    """Resize images & bbox & seg.

    This transform resizes the input image according to ``scale`` or
    ``scale_factor``. Bboxes, masks, and seg map are then resized
    with the same scale factor.
    If ``scale`` and ``scale_factor`` are both set, ``scale`` is used
    for resizing.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio
    - homography_matrix

    Args:
        scale (int or tuple): Image scales for resizing. Defaults to None.
        scale_factor (float or tuple[float]): Scale factors for resizing.
            Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to False.
        clip_object_border (bool): Whether to clip the objects outside the
            border of the image. In some datasets like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
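
    Examples:
        A minimal usage sketch; the image and scale below are illustrative,
        not taken from any config:

        >>> import numpy as np
        >>> transform = Resize(scale=(1333, 800), keep_ratio=True)
        >>> results = dict(img=np.zeros((480, 640, 3), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['img_shape']
        (800, 1067)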
  74. """
  75. def _resize_masks(self, results: dict) -> None:
  76. """Resize masks with ``results['scale']``"""
  77. if results.get('gt_masks', None) is not None:
  78. if self.keep_ratio:
  79. results['gt_masks'] = results['gt_masks'].rescale(
  80. results['scale'])
  81. else:
  82. results['gt_masks'] = results['gt_masks'].resize(
  83. results['img_shape'])
  84. def _resize_bboxes(self, results: dict) -> None:
  85. """Resize bounding boxes with ``results['scale_factor']``."""
  86. if results.get('gt_bboxes', None) is not None:
  87. results['gt_bboxes'].rescale_(results['scale_factor'])
  88. if self.clip_object_border:
  89. results['gt_bboxes'].clip_(results['img_shape'])
  90. def _resize_seg(self, results: dict) -> None:
  91. """Resize semantic segmentation map with ``results['scale']``."""
  92. if results.get('gt_seg_map', None) is not None:
  93. if self.keep_ratio:
  94. gt_seg = mmcv.imrescale(
  95. results['gt_seg_map'],
  96. results['scale'],
  97. interpolation='nearest',
  98. backend=self.backend)
  99. else:
  100. gt_seg = mmcv.imresize(
  101. results['gt_seg_map'],
  102. results['scale'],
  103. interpolation='nearest',
  104. backend=self.backend)
  105. results['gt_seg_map'] = gt_seg
  106. def _record_homography_matrix(self, results: dict) -> None:
  107. """Record the homography matrix for the Resize."""
  108. w_scale, h_scale = results['scale_factor']
  109. homography_matrix = np.array(
  110. [[w_scale, 0, 0], [0, h_scale, 0], [0, 0, 1]], dtype=np.float32)
  111. if results.get('homography_matrix', None) is None:
  112. results['homography_matrix'] = homography_matrix
  113. else:
  114. results['homography_matrix'] = homography_matrix @ results[
  115. 'homography_matrix']
  116. @autocast_box_type()
  117. def transform(self, results: dict) -> dict:
  118. """Transform function to resize images, bounding boxes and semantic
  119. segmentation map.
  120. Args:
  121. results (dict): Result dict from loading pipeline.
  122. Returns:
  123. dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
  124. 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
  125. are updated in result dict.
  126. """
  127. if self.scale:
  128. results['scale'] = self.scale
  129. else:
  130. img_shape = results['img'].shape[:2]
  131. results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)
  132. self._resize_img(results)
  133. self._resize_bboxes(results)
  134. self._resize_masks(results)
  135. self._resize_seg(results)
  136. self._record_homography_matrix(results)
  137. return results
  138. def __repr__(self) -> str:
  139. repr_str = self.__class__.__name__
  140. repr_str += f'(scale={self.scale}, '
  141. repr_str += f'scale_factor={self.scale_factor}, '
  142. repr_str += f'keep_ratio={self.keep_ratio}, '
  143. repr_str += f'clip_object_border={self.clip_object_border}), '
  144. repr_str += f'backend={self.backend}), '
  145. repr_str += f'interpolation={self.interpolation})'
  146. return repr_str


@TRANSFORMS.register_module()
class FixShapeResize(Resize):
    """Resize images & bbox & seg to the specified size.

    This transform resizes the input image according to ``width`` and
    ``height``. Bboxes, masks, and seg map are then resized
    with the same parameters.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio
    - homography_matrix

    Args:
        width (int): width for resizing.
        height (int): height for resizing.
        pad_val (Number | dict[str, Number], optional): Padding value used if
            the pad_mode is "constant". If it is a single number, the value
            to pad the image is the number and to pad the semantic
            segmentation map is 255. If it is a dict, it should have the
            following keys:

            - img: The value to pad the image.
            - seg: The value to pad the semantic segmentation map.

            Defaults to dict(img=0, seg=255).
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to False.
        clip_object_border (bool): Whether to clip the objects outside the
            border of the image. In some datasets like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
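
    Examples:
        A minimal sketch with an image-only ``results`` dict; the target
        shape is illustrative:

        >>> import numpy as np
        >>> transform = FixShapeResize(width=320, height=240)
        >>> results = dict(img=np.zeros((480, 640, 3), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['img_shape']
        (240, 320)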
  194. """
  195. def __init__(self,
  196. width: int,
  197. height: int,
  198. pad_val: Union[Number, dict] = dict(img=0, seg=255),
  199. keep_ratio: bool = False,
  200. clip_object_border: bool = True,
  201. backend: str = 'cv2',
  202. interpolation: str = 'bilinear') -> None:
  203. assert width is not None and height is not None, (
  204. '`width` and'
  205. '`height` can not be `None`')
  206. self.width = width
  207. self.height = height
  208. self.scale = (width, height)
  209. self.backend = backend
  210. self.interpolation = interpolation
  211. self.keep_ratio = keep_ratio
  212. self.clip_object_border = clip_object_border
  213. if keep_ratio is True:
  214. # padding to the fixed size when keep_ratio=True
  215. self.pad_transform = Pad(size=self.scale, pad_val=pad_val)
  216. @autocast_box_type()
  217. def transform(self, results: dict) -> dict:
  218. """Transform function to resize images, bounding boxes and semantic
  219. segmentation map.
  220. Args:
  221. results (dict): Result dict from loading pipeline.
  222. Returns:
  223. dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
  224. 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
  225. are updated in result dict.
  226. """
  227. img = results['img']
  228. h, w = img.shape[:2]
  229. if self.keep_ratio:
  230. scale_factor = min(self.width / w, self.height / h)
  231. results['scale_factor'] = (scale_factor, scale_factor)
  232. real_w, real_h = int(w * float(scale_factor) +
  233. 0.5), int(h * float(scale_factor) + 0.5)
  234. img, scale_factor = mmcv.imrescale(
  235. results['img'], (real_w, real_h),
  236. interpolation=self.interpolation,
  237. return_scale=True,
  238. backend=self.backend)
  239. # the w_scale and h_scale has minor difference
  240. # a real fix should be done in the mmcv.imrescale in the future
  241. results['img'] = img
  242. results['img_shape'] = img.shape[:2]
  243. results['keep_ratio'] = self.keep_ratio
  244. results['scale'] = (real_w, real_h)
  245. else:
  246. results['scale'] = (self.width, self.height)
  247. results['scale_factor'] = (self.width / w, self.height / h)
  248. super()._resize_img(results)
  249. self._resize_bboxes(results)
  250. self._resize_masks(results)
  251. self._resize_seg(results)
  252. self._record_homography_matrix(results)
  253. if self.keep_ratio:
  254. self.pad_transform(results)
  255. return results
  256. def __repr__(self) -> str:
  257. repr_str = self.__class__.__name__
  258. repr_str += f'(width={self.width}, height={self.height}, '
  259. repr_str += f'keep_ratio={self.keep_ratio}, '
  260. repr_str += f'clip_object_border={self.clip_object_border}), '
  261. repr_str += f'backend={self.backend}), '
  262. repr_str += f'interpolation={self.interpolation})'
  263. return repr_str


@TRANSFORMS.register_module()
class RandomFlip(MMCV_RandomFlip):
    """Flip the image & bbox & mask & segmentation map. Added or Updated keys:
    flip, flip_direction, img, gt_bboxes, and gt_seg_map. There are 3 flip
    modes:

    - ``prob`` is float, ``direction`` is string: the image will be
      ``direction``ly flipped with probability of ``prob``.
      E.g., ``prob=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.
    - ``prob`` is float, ``direction`` is list of string: the image will
      be ``direction[i]``ly flipped with probability of
      ``prob/len(direction)``.
      E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
      then image will be horizontally flipped with probability of 0.25,
      vertically with probability of 0.25.
    - ``prob`` is list of float, ``direction`` is list of string:
      given ``len(prob) == len(direction)``, the image will
      be ``direction[i]``ly flipped with probability of ``prob[i]``.
      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with
      probability of 0.3, vertically with probability of 0.5.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - flip
    - flip_direction
    - homography_matrix

    Args:
        prob (float | list[float], optional): The flipping probability.
            Defaults to None.
        direction(str | list[str]): The flipping direction. Options are
            'horizontal', 'vertical' and 'diagonal'. If input is a list,
            the length must equal ``prob``. Each element in ``prob``
            indicates the flip probability of the corresponding direction.
            Defaults to 'horizontal'.
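
    Examples:
        A minimal sketch with a deterministic flip (``prob=1.0`` forces the
        flip; the image is synthetic):

        >>> import numpy as np
        >>> transform = RandomFlip(prob=1.0, direction='horizontal')
        >>> results = dict(img=np.zeros((32, 64, 3), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['flip'], results['flip_direction']
        (True, 'horizontal')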
  306. """
  307. def _record_homography_matrix(self, results: dict) -> None:
  308. """Record the homography matrix for the RandomFlip."""
  309. cur_dir = results['flip_direction']
  310. h, w = results['img'].shape[:2]
  311. if cur_dir == 'horizontal':
  312. homography_matrix = np.array([[-1, 0, w], [0, 1, 0], [0, 0, 1]],
  313. dtype=np.float32)
  314. elif cur_dir == 'vertical':
  315. homography_matrix = np.array([[1, 0, 0], [0, -1, h], [0, 0, 1]],
  316. dtype=np.float32)
  317. elif cur_dir == 'diagonal':
  318. homography_matrix = np.array([[-1, 0, w], [0, -1, h], [0, 0, 1]],
  319. dtype=np.float32)
  320. else:
  321. homography_matrix = np.eye(3, dtype=np.float32)
  322. if results.get('homography_matrix', None) is None:
  323. results['homography_matrix'] = homography_matrix
  324. else:
  325. results['homography_matrix'] = homography_matrix @ results[
  326. 'homography_matrix']
  327. @autocast_box_type()
  328. def _flip(self, results: dict) -> None:
  329. """Flip images, bounding boxes, and semantic segmentation map."""
  330. # flip image
  331. results['img'] = mmcv.imflip(
  332. results['img'], direction=results['flip_direction'])
  333. img_shape = results['img'].shape[:2]
  334. # flip bboxes
  335. if results.get('gt_bboxes', None) is not None:
  336. results['gt_bboxes'].flip_(img_shape, results['flip_direction'])
  337. # flip masks
  338. if results.get('gt_masks', None) is not None:
  339. results['gt_masks'] = results['gt_masks'].flip(
  340. results['flip_direction'])
  341. # flip segs
  342. if results.get('gt_seg_map', None) is not None:
  343. results['gt_seg_map'] = mmcv.imflip(
  344. results['gt_seg_map'], direction=results['flip_direction'])
  345. # record homography matrix for flip
  346. self._record_homography_matrix(results)


@TRANSFORMS.register_module()
class RandomShift(BaseTransform):
    """Shift the image and boxes given shift pixels and probability.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_bboxes_labels
    - gt_ignore_flags (bool) (optional)

    Args:
        prob (float): Probability of shifts. Defaults to 0.5.
        max_shift_px (int): The max pixels for shifting. Defaults to 32.
        filter_thr_px (int): The width and height threshold for filtering.
            Bboxes (and their associated targets) whose width or height is
            below this threshold will be filtered out. Defaults to 1.
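
    Examples:
        A minimal sketch; the box coordinates are illustrative, and the
        shift itself is random, so no output is asserted:

        >>> import numpy as np
        >>> from mmdet.structures.bbox import HorizontalBoxes
        >>> transform = RandomShift(prob=1.0, max_shift_px=16)
        >>> results = dict(
        ...     img=np.zeros((64, 64, 3), dtype=np.uint8),
        ...     gt_bboxes=HorizontalBoxes(
        ...         np.array([[8, 8, 32, 32]], dtype=np.float32)),
        ...     gt_bboxes_labels=np.array([0], dtype=np.int64))
        >>> results = transform(results)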
  366. """
  367. def __init__(self,
  368. prob: float = 0.5,
  369. max_shift_px: int = 32,
  370. filter_thr_px: int = 1) -> None:
  371. assert 0 <= prob <= 1
  372. assert max_shift_px >= 0
  373. self.prob = prob
  374. self.max_shift_px = max_shift_px
  375. self.filter_thr_px = int(filter_thr_px)
  376. @cache_randomness
  377. def _random_prob(self) -> float:
  378. return random.uniform(0, 1)
  379. @autocast_box_type()
  380. def transform(self, results: dict) -> dict:
  381. """Transform function to random shift images, bounding boxes.
  382. Args:
  383. results (dict): Result dict from loading pipeline.
  384. Returns:
  385. dict: Shift results.
  386. """
  387. if self._random_prob() < self.prob:
  388. img_shape = results['img'].shape[:2]
  389. random_shift_x = random.randint(-self.max_shift_px,
  390. self.max_shift_px)
  391. random_shift_y = random.randint(-self.max_shift_px,
  392. self.max_shift_px)
  393. new_x = max(0, random_shift_x)
  394. ori_x = max(0, -random_shift_x)
  395. new_y = max(0, random_shift_y)
  396. ori_y = max(0, -random_shift_y)
  397. # TODO: support mask and semantic segmentation maps.
  398. bboxes = results['gt_bboxes'].clone()
  399. bboxes.translate_([random_shift_x, random_shift_y])
  400. # clip border
  401. bboxes.clip_(img_shape)
  402. # remove invalid bboxes
  403. valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
  404. bboxes.heights > self.filter_thr_px).numpy()
  405. # If the shift does not contain any gt-bbox area, skip this
  406. # image.
  407. if not valid_inds.any():
  408. return results
  409. bboxes = bboxes[valid_inds]
  410. results['gt_bboxes'] = bboxes
  411. results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
  412. valid_inds]
  413. if results.get('gt_ignore_flags', None) is not None:
  414. results['gt_ignore_flags'] = \
  415. results['gt_ignore_flags'][valid_inds]
  416. # shift img
  417. img = results['img']
  418. new_img = np.zeros_like(img)
  419. img_h, img_w = img.shape[:2]
  420. new_h = img_h - np.abs(random_shift_y)
  421. new_w = img_w - np.abs(random_shift_x)
  422. new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
  423. = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
  424. results['img'] = new_img
  425. return results
  426. def __repr__(self):
  427. repr_str = self.__class__.__name__
  428. repr_str += f'(prob={self.prob}, '
  429. repr_str += f'max_shift_px={self.max_shift_px}, '
  430. repr_str += f'filter_thr_px={self.filter_thr_px})'
  431. return repr_str


@TRANSFORMS.register_module()
class Pad(MMCV_Pad):
    """Pad the image & segmentation map.

    There are three padding modes: (1) pad to a fixed size, (2) pad to the
    minimum size that is divisible by some number, and (3) pad to square.
    Padding to square and padding to the minimum divisible size can also be
    used at the same time.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_masks
    - gt_seg_map

    Added Keys:

    - pad_shape
    - pad_fixed_size
    - pad_size_divisor

    Args:
        size (tuple, optional): Fixed padding size.
            Expected padding shape (width, height). Defaults to None.
        size_divisor (int, optional): The divisor of padded size. Defaults to
            None.
        pad_to_square (bool): Whether to pad the image into a square.
            Currently only used for YOLOX. Defaults to False.
        pad_val (Number | dict[str, Number], optional): Padding value used if
            the pad_mode is "constant". If it is a single number, the value
            to pad the image is the number and to pad the semantic
            segmentation map is 255. If it is a dict, it should have the
            following keys:

            - img: The value to pad the image.
            - seg: The value to pad the semantic segmentation map.

            Defaults to dict(img=0, seg=255).
        padding_mode (str): Type of padding. Should be: constant, edge,
            reflect or symmetric. Defaults to 'constant'.

            - constant: pads with a constant value, this value is specified
              with pad_val.
            - edge: pads with the last value at the edge of the image.
            - reflect: pads with reflection of image without repeating the
              last value on the edge. For example, padding [1, 2, 3, 4]
              with 2 elements on both sides in reflect mode will result in
              [3, 2, 1, 2, 3, 4, 3, 2].
            - symmetric: pads with reflection of image repeating the last
              value on the edge. For example, padding [1, 2, 3, 4] with 2
              elements on both sides in symmetric mode will result in
              [2, 1, 1, 2, 3, 4, 4, 3].
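
    Examples:
        A minimal sketch using ``size_divisor``; the image shape is
        illustrative:

        >>> import numpy as np
        >>> transform = Pad(size_divisor=32)
        >>> results = dict(img=np.zeros((100, 120, 3), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['pad_shape']
        (128, 128, 3)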
  480. """
  481. def _pad_masks(self, results: dict) -> None:
  482. """Pad masks according to ``results['pad_shape']``."""
  483. if results.get('gt_masks', None) is not None:
  484. pad_val = self.pad_val.get('masks', 0)
  485. pad_shape = results['pad_shape'][:2]
  486. results['gt_masks'] = results['gt_masks'].pad(
  487. pad_shape, pad_val=pad_val)
  488. def transform(self, results: dict) -> dict:
  489. """Call function to pad images, masks, semantic segmentation maps.
  490. Args:
  491. results (dict): Result dict from loading pipeline.
  492. Returns:
  493. dict: Updated result dict.
  494. """
  495. self._pad_img(results)
  496. self._pad_seg(results)
  497. self._pad_masks(results)
  498. return results


@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
    """Random crop the image & bboxes & masks.

    The absolute ``crop_size`` is sampled based on ``crop_type`` and
    ``image_size``, then the cropped results are generated.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_masks (optional)
    - gt_ignore_flags (optional)
    - gt_seg_map (optional)

    Added Keys:

    - homography_matrix

    Args:
        crop_size (tuple): The relative ratio or absolute pixels of
            (width, height).
        crop_type (str, optional): One of "relative_range", "relative",
            "absolute", "absolute_range". "relative" randomly crops
            (h * crop_size[0], w * crop_size[1]) part from an input of size
            (h, w). "relative_range" uniformly samples relative crop size
            from range [crop_size[0], 1] and [crop_size[1], 1] for height
            and width respectively. "absolute" crops from an input with
            absolute size (crop_size[0], crop_size[1]). "absolute_range"
            uniformly samples crop_h in range
            [crop_size[0], min(h, crop_size[1])] and crop_w in range
            [crop_size[0], min(w, crop_size[1])]. Defaults to "absolute".
        allow_negative_crop (bool, optional): Whether to allow a crop that
            does not contain any bbox area. Defaults to False.
        recompute_bbox (bool, optional): Whether to re-compute the boxes
            based on cropped instance masks. Defaults to False.
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.

    Note:
        - If the image is smaller than the absolute crop size, return the
          original image.
        - The keys for bboxes, labels and masks must be aligned. That is,
          ``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
          ``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
          ``gt_masks_ignore``.
        - If the crop does not contain any gt-bbox region and
          ``allow_negative_crop`` is set to False, skip this image.
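
    Examples:
        A minimal sketch with an absolute crop (no gt boxes, so the crop is
        always kept; the sizes are illustrative):

        >>> import numpy as np
        >>> transform = RandomCrop(crop_size=(256, 256), crop_type='absolute')
        >>> results = dict(img=np.zeros((300, 400, 3), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['img_shape']
        (256, 256)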
  549. """
  550. def __init__(self,
  551. crop_size: tuple,
  552. crop_type: str = 'absolute',
  553. allow_negative_crop: bool = False,
  554. recompute_bbox: bool = False,
  555. bbox_clip_border: bool = True) -> None:
  556. if crop_type not in [
  557. 'relative_range', 'relative', 'absolute', 'absolute_range'
  558. ]:
  559. raise ValueError(f'Invalid crop_type {crop_type}.')
  560. if crop_type in ['absolute', 'absolute_range']:
  561. assert crop_size[0] > 0 and crop_size[1] > 0
  562. assert isinstance(crop_size[0], int) and isinstance(
  563. crop_size[1], int)
  564. if crop_type == 'absolute_range':
  565. assert crop_size[0] <= crop_size[1]
  566. else:
  567. assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
  568. self.crop_size = crop_size
  569. self.crop_type = crop_type
  570. self.allow_negative_crop = allow_negative_crop
  571. self.bbox_clip_border = bbox_clip_border
  572. self.recompute_bbox = recompute_bbox
  573. def _crop_data(self, results: dict, crop_size: Tuple[int, int],
  574. allow_negative_crop: bool) -> Union[dict, None]:
  575. """Function to randomly crop images, bounding boxes, masks, semantic
  576. segmentation maps.
  577. Args:
  578. results (dict): Result dict from loading pipeline.
  579. crop_size (Tuple[int, int]): Expected absolute size after
  580. cropping, (h, w).
  581. allow_negative_crop (bool): Whether to allow a crop that does not
  582. contain any bbox area.
  583. Returns:
  584. results (Union[dict, None]): Randomly cropped results, 'img_shape'
  585. key in result dict is updated according to crop size. None will
  586. be returned when there is no valid bbox after cropping.
  587. """
  588. assert crop_size[0] > 0 and crop_size[1] > 0
  589. img = results['img']
  590. margin_h = max(img.shape[0] - crop_size[0], 0)
  591. margin_w = max(img.shape[1] - crop_size[1], 0)
  592. offset_h, offset_w = self._rand_offset((margin_h, margin_w))
  593. crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
  594. crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
  595. # Record the homography matrix for the RandomCrop
  596. homography_matrix = np.array(
  597. [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
  598. dtype=np.float32)
  599. if results.get('homography_matrix', None) is None:
  600. results['homography_matrix'] = homography_matrix
  601. else:
  602. results['homography_matrix'] = homography_matrix @ results[
  603. 'homography_matrix']
  604. # crop the image
  605. img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
  606. img_shape = img.shape
  607. results['img'] = img
  608. results['img_shape'] = img_shape[:2]
  609. # crop bboxes accordingly and clip to the image boundary
  610. if results.get('gt_bboxes', None) is not None:
  611. bboxes = results['gt_bboxes']
  612. bboxes.translate_([-offset_w, -offset_h])
  613. if self.bbox_clip_border:
  614. bboxes.clip_(img_shape[:2])
  615. valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
  616. # If the crop does not contain any gt-bbox area and
  617. # allow_negative_crop is False, skip this image.
  618. if (not valid_inds.any() and not allow_negative_crop):
  619. return None
  620. results['gt_bboxes'] = bboxes[valid_inds]
  621. if results.get('gt_ignore_flags', None) is not None:
  622. results['gt_ignore_flags'] = \
  623. results['gt_ignore_flags'][valid_inds]
  624. if results.get('gt_bboxes_labels', None) is not None:
  625. results['gt_bboxes_labels'] = \
  626. results['gt_bboxes_labels'][valid_inds]
  627. if results.get('gt_masks', None) is not None:
  628. results['gt_masks'] = results['gt_masks'][
  629. valid_inds.nonzero()[0]].crop(
  630. np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
  631. if self.recompute_bbox:
  632. results['gt_bboxes'] = results['gt_masks'].get_bboxes(
  633. type(results['gt_bboxes']))
  634. # crop semantic seg
  635. if results.get('gt_seg_map', None) is not None:
  636. results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
  637. crop_x1:crop_x2]
  638. return results
  639. @cache_randomness
  640. def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
  641. """Randomly generate crop offset.
  642. Args:
  643. margin (Tuple[int, int]): The upper bound for the offset generated
  644. randomly.
  645. Returns:
  646. Tuple[int, int]: The random offset for the crop.
  647. """
  648. margin_h, margin_w = margin
  649. offset_h = np.random.randint(0, margin_h + 1)
  650. offset_w = np.random.randint(0, margin_w + 1)
  651. return offset_h, offset_w
  652. @cache_randomness
  653. def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
  654. """Randomly generates the absolute crop size based on `crop_type` and
  655. `image_size`.
  656. Args:
  657. image_size (Tuple[int, int]): (h, w).
  658. Returns:
  659. crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
  660. """
  661. h, w = image_size
  662. if self.crop_type == 'absolute':
  663. return min(self.crop_size[1], h), min(self.crop_size[0], w)
  664. elif self.crop_type == 'absolute_range':
  665. crop_h = np.random.randint(
  666. min(h, self.crop_size[0]),
  667. min(h, self.crop_size[1]) + 1)
  668. crop_w = np.random.randint(
  669. min(w, self.crop_size[0]),
  670. min(w, self.crop_size[1]) + 1)
  671. return crop_h, crop_w
  672. elif self.crop_type == 'relative':
  673. crop_w, crop_h = self.crop_size
  674. return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
  675. else:
  676. # 'relative_range'
  677. crop_size = np.asarray(self.crop_size, dtype=np.float32)
  678. crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
  679. return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
  680. @autocast_box_type()
  681. def transform(self, results: dict) -> Union[dict, None]:
  682. """Transform function to randomly crop images, bounding boxes, masks,
  683. semantic segmentation maps.
  684. Args:
  685. results (dict): Result dict from loading pipeline.
  686. Returns:
  687. results (Union[dict, None]): Randomly cropped results, 'img_shape'
  688. key in result dict is updated according to crop size. None will
  689. be returned when there is no valid bbox after cropping.
  690. """
  691. image_size = results['img'].shape[:2]
  692. crop_size = self._get_crop_size(image_size)
  693. results = self._crop_data(results, crop_size, self.allow_negative_crop)
  694. return results
  695. def __repr__(self) -> str:
  696. repr_str = self.__class__.__name__
  697. repr_str += f'(crop_size={self.crop_size}, '
  698. repr_str += f'crop_type={self.crop_type}, '
  699. repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
  700. repr_str += f'recompute_bbox={self.recompute_bbox}, '
  701. repr_str += f'bbox_clip_border={self.bbox_clip_border})'
  702. return repr_str


@TRANSFORMS.register_module()
class SegRescale(BaseTransform):
    """Rescale semantic segmentation maps.

    This transform rescales the ``gt_seg_map`` according to ``scale_factor``.

    Required Keys:

    - gt_seg_map

    Modified Keys:

    - gt_seg_map

    Args:
        scale_factor (float): The scale factor of the final output. Defaults
            to 1.
        backend (str): Image rescale backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
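
    Examples:
        A minimal sketch; the map size is illustrative:

        >>> import numpy as np
        >>> transform = SegRescale(scale_factor=0.5)
        >>> results = dict(gt_seg_map=np.zeros((64, 64), dtype=np.uint8))
        >>> results = transform(results)
        >>> results['gt_seg_map'].shape
        (32, 32)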
  717. """
  718. def __init__(self, scale_factor: float = 1, backend: str = 'cv2') -> None:
  719. self.scale_factor = scale_factor
  720. self.backend = backend
  721. def transform(self, results: dict) -> dict:
  722. """Transform function to scale the semantic segmentation map.
  723. Args:
  724. results (dict): Result dict from loading pipeline.
  725. Returns:
  726. dict: Result dict with semantic segmentation map scaled.
  727. """
  728. if self.scale_factor != 1:
  729. results['gt_seg_map'] = mmcv.imrescale(
  730. results['gt_seg_map'],
  731. self.scale_factor,
  732. interpolation='nearest',
  733. backend=self.backend)
  734. return results
  735. def __repr__(self) -> str:
  736. repr_str = self.__class__.__name__
  737. repr_str += f'(scale_factor={self.scale_factor}, '
  738. repr_str += f'backend={self.backend})'
  739. return repr_str


@TRANSFORMS.register_module()
class PhotoMetricDistortion(BaseTransform):
    """Apply photometric distortion to an image sequentially; every
    transformation is applied with a probability of 0.5. Random contrast is
    applied either second or second to last:

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels

    Required Keys:

    - img (np.uint8)

    Modified Keys:

    - img (np.float32)

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (sequence): range of contrast.
        saturation_range (sequence): range of saturation.
        hue_delta (int): delta of hue.
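
    Examples:
        A minimal sketch; the input is random noise, and the distortions
        themselves are sampled, so only the output dtype is asserted:

        >>> import numpy as np
        >>> transform = PhotoMetricDistortion()
        >>> img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        >>> results = transform(dict(img=img))
        >>> results['img'].dtype
        dtype('float32')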
  762. """
  763. def __init__(self,
  764. brightness_delta: int = 32,
  765. contrast_range: Sequence[Number] = (0.5, 1.5),
  766. saturation_range: Sequence[Number] = (0.5, 1.5),
  767. hue_delta: int = 18) -> None:
  768. self.brightness_delta = brightness_delta
  769. self.contrast_lower, self.contrast_upper = contrast_range
  770. self.saturation_lower, self.saturation_upper = saturation_range
  771. self.hue_delta = hue_delta
  772. @cache_randomness
  773. def _random_flags(self) -> Sequence[Number]:
  774. mode = random.randint(2)
  775. brightness_flag = random.randint(2)
  776. contrast_flag = random.randint(2)
  777. saturation_flag = random.randint(2)
  778. hue_flag = random.randint(2)
  779. swap_flag = random.randint(2)
  780. delta_value = random.uniform(-self.brightness_delta,
  781. self.brightness_delta)
  782. alpha_value = random.uniform(self.contrast_lower, self.contrast_upper)
  783. saturation_value = random.uniform(self.saturation_lower,
  784. self.saturation_upper)
  785. hue_value = random.uniform(-self.hue_delta, self.hue_delta)
  786. swap_value = random.permutation(3)
  787. return (mode, brightness_flag, contrast_flag, saturation_flag,
  788. hue_flag, swap_flag, delta_value, alpha_value,
  789. saturation_value, hue_value, swap_value)
  790. def transform(self, results: dict) -> dict:
  791. """Transform function to perform photometric distortion on images.
  792. Args:
  793. results (dict): Result dict from loading pipeline.
  794. Returns:
  795. dict: Result dict with images distorted.
  796. """
  797. assert 'img' in results, '`img` is not found in results'
  798. img = results['img']
  799. img = img.astype(np.float32)
  800. (mode, brightness_flag, contrast_flag, saturation_flag, hue_flag,
  801. swap_flag, delta_value, alpha_value, saturation_value, hue_value,
  802. swap_value) = self._random_flags()
  803. # random brightness
  804. if brightness_flag:
  805. img += delta_value
  806. # mode == 0 --> do random contrast first
  807. # mode == 1 --> do random contrast last
  808. if mode == 1:
  809. if contrast_flag:
  810. img *= alpha_value
  811. # convert color from BGR to HSV
  812. img = mmcv.bgr2hsv(img)
  813. # random saturation
  814. if saturation_flag:
  815. img[..., 1] *= saturation_value
  816. # For image(type=float32), after convert bgr to hsv by opencv,
  817. # valid saturation value range is [0, 1]
  818. if saturation_value > 1:
  819. img[..., 1] = img[..., 1].clip(0, 1)
  820. # random hue
  821. if hue_flag:
  822. img[..., 0] += hue_value
  823. img[..., 0][img[..., 0] > 360] -= 360
  824. img[..., 0][img[..., 0] < 0] += 360
  825. # convert color from HSV to BGR
  826. img = mmcv.hsv2bgr(img)
  827. # random contrast
  828. if mode == 0:
  829. if contrast_flag:
  830. img *= alpha_value
  831. # randomly swap channels
  832. if swap_flag:
  833. img = img[..., swap_value]
  834. results['img'] = img
  835. return results
  836. def __repr__(self) -> str:
  837. repr_str = self.__class__.__name__
  838. repr_str += f'(brightness_delta={self.brightness_delta}, '
  839. repr_str += 'contrast_range='
  840. repr_str += f'{(self.contrast_lower, self.contrast_upper)}, '
  841. repr_str += 'saturation_range='
  842. repr_str += f'{(self.saturation_lower, self.saturation_upper)}, '
  843. repr_str += f'hue_delta={self.hue_delta})'
  844. return repr_str


@TRANSFORMS.register_module()
class Expand(BaseTransform):
    """Random expand the image & bboxes & masks & segmentation map.

    Randomly place the original image on a canvas of ``ratio`` x original
    image size filled with mean values. The ratio is in the range of
    ratio_range.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Args:
        mean (sequence): mean value of dataset.
        to_rgb (bool): whether to convert the order of mean values to align
            with RGB.
        ratio_range (sequence): range of expand ratio.
        seg_ignore_label (int): label used to fill the expanded area of the
            segmentation map.
        prob (float): probability of applying this transformation.
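
    Examples:
        A minimal sketch with a fixed expand ratio so the output shape is
        deterministic; the mean values are illustrative:

        >>> import numpy as np
        >>> transform = Expand(
        ...     mean=(104, 116, 124), ratio_range=(2, 2), prob=1.0)
        >>> results = dict(
        ...     img=np.zeros((32, 32, 3), dtype=np.uint8), img_shape=(32, 32))
        >>> results = transform(results)
        >>> results['img_shape']
        (64, 64)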
  868. """
  869. def __init__(self,
  870. mean: Sequence[Number] = (0, 0, 0),
  871. to_rgb: bool = True,
  872. ratio_range: Sequence[Number] = (1, 4),
  873. seg_ignore_label: int = None,
  874. prob: float = 0.5) -> None:
  875. self.to_rgb = to_rgb
  876. self.ratio_range = ratio_range
  877. if to_rgb:
  878. self.mean = mean[::-1]
  879. else:
  880. self.mean = mean
  881. self.min_ratio, self.max_ratio = ratio_range
  882. self.seg_ignore_label = seg_ignore_label
  883. self.prob = prob
  884. @cache_randomness
  885. def _random_prob(self) -> float:
  886. return random.uniform(0, 1)
  887. @cache_randomness
  888. def _random_ratio(self) -> float:
  889. return random.uniform(self.min_ratio, self.max_ratio)
  890. @cache_randomness
  891. def _random_left_top(self, ratio: float, h: int,
  892. w: int) -> Tuple[int, int]:
  893. left = int(random.uniform(0, w * ratio - w))
  894. top = int(random.uniform(0, h * ratio - h))
  895. return left, top
  896. @autocast_box_type()
  897. def transform(self, results: dict) -> dict:
  898. """Transform function to expand images, bounding boxes, masks,
  899. segmentation map.
  900. Args:
  901. results (dict): Result dict from loading pipeline.
  902. Returns:
  903. dict: Result dict with images, bounding boxes, masks, segmentation
  904. map expanded.
  905. """
  906. if self._random_prob() > self.prob:
  907. return results
  908. assert 'img' in results, '`img` is not found in results'
  909. img = results['img']
  910. h, w, c = img.shape
  911. ratio = self._random_ratio()
  912. # speedup expand when meets large image
  913. if np.all(self.mean == self.mean[0]):
  914. expand_img = np.empty((int(h * ratio), int(w * ratio), c),
  915. img.dtype)
  916. expand_img.fill(self.mean[0])
  917. else:
  918. expand_img = np.full((int(h * ratio), int(w * ratio), c),
  919. self.mean,
  920. dtype=img.dtype)
  921. left, top = self._random_left_top(ratio, h, w)
  922. expand_img[top:top + h, left:left + w] = img
  923. results['img'] = expand_img
  924. results['img_shape'] = expand_img.shape[:2]
  925. # expand bboxes
  926. if results.get('gt_bboxes', None) is not None:
  927. results['gt_bboxes'].translate_([left, top])
  928. # expand masks
  929. if results.get('gt_masks', None) is not None:
  930. results['gt_masks'] = results['gt_masks'].expand(
  931. int(h * ratio), int(w * ratio), top, left)
  932. # expand segmentation map
  933. if results.get('gt_seg_map', None) is not None:
  934. gt_seg = results['gt_seg_map']
  935. expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
  936. self.seg_ignore_label,
  937. dtype=gt_seg.dtype)
  938. expand_gt_seg[top:top + h, left:left + w] = gt_seg
  939. results['gt_seg_map'] = expand_gt_seg
  940. return results
  941. def __repr__(self) -> str:
  942. repr_str = self.__class__.__name__
  943. repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
  944. repr_str += f'ratio_range={self.ratio_range}, '
  945. repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
  946. repr_str += f'prob={self.prob})'
  947. return repr_str


@TRANSFORMS.register_module()
class MinIoURandomCrop(BaseTransform):
    """Random crop the image & bboxes & masks & segmentation map. The cropped
    patches have a minimum IoU requirement with the original bboxes, and the
    IoU threshold is randomly selected from ``min_ious``.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_bboxes_labels
    - gt_masks
    - gt_ignore_flags
    - gt_seg_map

    Args:
        min_ious (Sequence[float]): minimum IoU threshold for all
            intersections with bounding boxes.
        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size).
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.
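
    Examples:
        A minimal sketch; the box is illustrative, and the crop is random,
        so no output is asserted:

        >>> import numpy as np
        >>> from mmdet.structures.bbox import HorizontalBoxes
        >>> transform = MinIoURandomCrop()
        >>> results = dict(
        ...     img=np.zeros((100, 100, 3), dtype=np.uint8),
        ...     img_shape=(100, 100),
        ...     gt_bboxes=HorizontalBoxes(
        ...         np.array([[10, 10, 60, 60]], dtype=np.float32)),
        ...     gt_bboxes_labels=np.array([0], dtype=np.int64))
        >>> results = transform(results)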
  976. """
  977. def __init__(self,
  978. min_ious: Sequence[float] = (0.1, 0.3, 0.5, 0.7, 0.9),
  979. min_crop_size: float = 0.3,
  980. bbox_clip_border: bool = True) -> None:
  981. self.min_ious = min_ious
  982. self.sample_mode = (1, *min_ious, 0)
  983. self.min_crop_size = min_crop_size
  984. self.bbox_clip_border = bbox_clip_border
  985. @cache_randomness
  986. def _random_mode(self) -> Number:
  987. return random.choice(self.sample_mode)
  988. @autocast_box_type()
  989. def transform(self, results: dict) -> dict:
  990. """Transform function to crop images and bounding boxes with minimum
  991. IoU constraint.
  992. Args:
  993. results (dict): Result dict from loading pipeline.
  994. Returns:
  995. dict: Result dict with images and bounding boxes cropped, \
  996. 'img_shape' key is updated.
  997. """
  998. assert 'img' in results, '`img` is not found in results'
  999. assert 'gt_bboxes' in results, '`gt_bboxes` is not found in results'
  1000. img = results['img']
  1001. boxes = results['gt_bboxes']
  1002. h, w, c = img.shape
  1003. while True:
  1004. mode = self._random_mode()
  1005. self.mode = mode
  1006. if mode == 1:
  1007. return results
  1008. min_iou = self.mode
  1009. for i in range(50):
  1010. new_w = random.uniform(self.min_crop_size * w, w)
  1011. new_h = random.uniform(self.min_crop_size * h, h)
  1012. # h / w in [0.5, 2]
  1013. if new_h / new_w < 0.5 or new_h / new_w > 2:
  1014. continue
  1015. left = random.uniform(w - new_w)
  1016. top = random.uniform(h - new_h)
  1017. patch = np.array(
  1018. (int(left), int(top), int(left + new_w), int(top + new_h)))
  1019. # Line or point crop is not allowed
  1020. if patch[2] == patch[0] or patch[3] == patch[1]:
  1021. continue
  1022. overlaps = boxes.overlaps(
  1023. HorizontalBoxes(patch.reshape(-1, 4).astype(np.float32)),
  1024. boxes).numpy().reshape(-1)
  1025. if len(overlaps) > 0 and overlaps.min() < min_iou:
  1026. continue
  1027. # center of boxes should inside the crop img
  1028. # only adjust boxes and instance masks when the gt is not empty
  1029. if len(overlaps) > 0:
  1030. # adjust boxes
  1031. def is_center_of_bboxes_in_patch(boxes, patch):
  1032. centers = boxes.centers.numpy()
  1033. mask = ((centers[:, 0] > patch[0]) *
  1034. (centers[:, 1] > patch[1]) *
  1035. (centers[:, 0] < patch[2]) *
  1036. (centers[:, 1] < patch[3]))
  1037. return mask
  1038. mask = is_center_of_bboxes_in_patch(boxes, patch)
  1039. if not mask.any():
  1040. continue
  1041. if results.get('gt_bboxes', None) is not None:
  1042. boxes = results['gt_bboxes']
  1043. mask = is_center_of_bboxes_in_patch(boxes, patch)
  1044. boxes = boxes[mask]
  1045. boxes.translate_([-patch[0], -patch[1]])
  1046. if self.bbox_clip_border:
  1047. boxes.clip_(
  1048. [patch[3] - patch[1], patch[2] - patch[0]])
  1049. results['gt_bboxes'] = boxes
  1050. # ignore_flags
  1051. if results.get('gt_ignore_flags', None) is not None:
  1052. results['gt_ignore_flags'] = \
  1053. results['gt_ignore_flags'][mask]
  1054. # labels
  1055. if results.get('gt_bboxes_labels', None) is not None:
  1056. results['gt_bboxes_labels'] = results[
  1057. 'gt_bboxes_labels'][mask]
  1058. # mask fields
  1059. if results.get('gt_masks', None) is not None:
  1060. results['gt_masks'] = results['gt_masks'][
  1061. mask.nonzero()[0]].crop(patch)
  1062. # adjust the img no matter whether the gt is empty before crop
  1063. img = img[patch[1]:patch[3], patch[0]:patch[2]]
  1064. results['img'] = img
  1065. results['img_shape'] = img.shape[:2]
  1066. # seg fields
  1067. if results.get('gt_seg_map', None) is not None:
  1068. results['gt_seg_map'] = results['gt_seg_map'][
  1069. patch[1]:patch[3], patch[0]:patch[2]]
  1070. return results
  1071. def __repr__(self) -> str:
  1072. repr_str = self.__class__.__name__
  1073. repr_str += f'(min_ious={self.min_ious}, '
  1074. repr_str += f'min_crop_size={self.min_crop_size}, '
  1075. repr_str += f'bbox_clip_border={self.bbox_clip_border})'
  1076. return repr_str
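
# A minimal pipeline sketch showing how MinIoURandomCrop is typically wired
# into a detection training pipeline via config dicts. The loading
# transforms and their parameters below are assumptions for illustration;
# the keys they produce must match the Required Keys in the docstring above.
#
#     train_pipeline = [
#         dict(type='LoadImageFromFile'),
#         dict(type='LoadAnnotations', with_bbox=True),
#         dict(type='MinIoURandomCrop',
#              min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
#              min_crop_size=0.3),
#     ]
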
@TRANSFORMS.register_module()
class Corrupt(BaseTransform):
    """Corruption augmentation.

    Corruption transforms implemented based on
    `imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.

    Required Keys:

    - img (np.uint8)

    Modified Keys:

    - img (np.uint8)

    Args:
        corruption (str): Corruption name.
        severity (int): The severity of corruption. Defaults to 1.
    """

    def __init__(self, corruption: str, severity: int = 1) -> None:
        self.corruption = corruption
        self.severity = severity

    def transform(self, results: dict) -> dict:
        """Call function to corrupt image.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images corrupted.
        """
        if corrupt is None:
            raise RuntimeError('imagecorruptions is not installed')
        results['img'] = corrupt(
            results['img'].astype(np.uint8),
            corruption_name=self.corruption,
            severity=self.severity)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(corruption={self.corruption}, '
        repr_str += f'severity={self.severity})'
        return repr_str
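
# Direct-use sketch, assuming the optional imagecorruptions package is
# installed; 'gaussian_noise' is one of its corruption names and `img` is a
# hypothetical np.uint8 HxWxC array produced by the loading pipeline.
#
#     corrupt_aug = Corrupt(corruption='gaussian_noise', severity=3)
#     results = corrupt_aug({'img': img})  # results['img'] is corrupted
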
@TRANSFORMS.register_module()
@avoid_cache_randomness
class Albu(BaseTransform):
    """Albumentation augmentation.

    Adds custom transformations from the Albumentations library.
    Please visit `https://albumentations.readthedocs.io`
    for more information.

    Required Keys:

    - img (np.uint8)
    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)

    Modified Keys:

    - img (np.uint8)
    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - img_shape (tuple)

    An example of ``transforms`` is as follows:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (list[dict]): A list of albu transformations.
        bbox_params (dict, optional): Bbox_params for albumentations
            ``Compose``.
        keymap (dict, optional): Mapping of
            {'input key': 'albumentation-style key'}.
        skip_img_without_anno (bool): Whether to skip the image if no
            annotations are left after augmentation. Defaults to False.
    """

    def __init__(self,
                 transforms: List[dict],
                 bbox_params: Optional[dict] = None,
                 keymap: Optional[dict] = None,
                 skip_img_without_anno: bool = False) -> None:
        if Compose is None:
            raise RuntimeError('albumentations is not installed')

        # Args will be modified later, copying it will be safer
        transforms = copy.deepcopy(transforms)
        if bbox_params is not None:
            bbox_params = copy.deepcopy(bbox_params)
        if keymap is not None:
            keymap = copy.deepcopy(keymap)
        self.transforms = transforms
        self.filter_lost_elements = False
        self.skip_img_without_anno = skip_img_without_anno

        # A simple workaround to remove masks without boxes
        if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
                and 'filter_lost_elements' in bbox_params):
            self.filter_lost_elements = True
            self.origin_label_fields = bbox_params['label_fields']
            bbox_params['label_fields'] = ['idx_mapper']
            del bbox_params['filter_lost_elements']

        self.bbox_params = (
            self.albu_builder(bbox_params) if bbox_params else None)
        self.aug = Compose([self.albu_builder(t) for t in self.transforms],
                           bbox_params=self.bbox_params)

        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
                'gt_masks': 'masks',
                'gt_bboxes': 'bboxes'
            }
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg: dict) -> albumentations:
        """Import a module from albumentations.

        It inherits some of :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key
                "type".

        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg
        args = cfg.copy()

        obj_type = args.pop('type')
        if is_str(obj_type):
            if albumentations is None:
                raise RuntimeError('albumentations is not installed')
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d: dict, keymap: dict) -> dict:
        """Dictionary mapper. Renames keys according to keymap provided.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key': 'new_key'}

        Returns:
            dict: new dict.
        """
        updated_dict = {}
        for k, v in d.items():
            new_k = keymap.get(k, k)
            updated_dict[new_k] = v
        return updated_dict

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function of Albu."""
        # TODO: gt_seg_map is not currently supported
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        results, ori_masks = self._preprocess_results(results)
        results = self.aug(**results)
        results = self._postprocess_results(results, ori_masks)
        if results is None:
            return None
        # back to the original format
        results = self.mapper(results, self.keymap_back)
        results['img_shape'] = results['img'].shape[:2]
        return results

    def _preprocess_results(self, results: dict) -> tuple:
        """Pre-processing results to facilitate the use of Albu."""
        if 'bboxes' in results:
            # to list of boxes
            if not isinstance(results['bboxes'], HorizontalBoxes):
                raise NotImplementedError(
                    'Albu only supports horizontal boxes now')
            bboxes = results['bboxes'].numpy()
            results['bboxes'] = [x for x in bboxes]
            # add pseudo-field for filtration
            if self.filter_lost_elements:
                results['idx_mapper'] = np.arange(len(results['bboxes']))

        # TODO: Support mask structure in albu
        ori_masks = None
        if 'masks' in results:
            if isinstance(results['masks'], PolygonMasks):
                raise NotImplementedError(
                    'Albu only supports BitMap masks now')
            ori_masks = results['masks']
            if albumentations.__version__ < '0.5':
                results['masks'] = results['masks'].masks
            else:
                results['masks'] = [mask for mask in results['masks'].masks]

        return results, ori_masks

    def _postprocess_results(
            self,
            results: dict,
            ori_masks: Optional[Union[BitmapMasks,
                                      PolygonMasks]] = None) -> dict:
        """Post-processing Albu output."""
        # albumentations may return np.array or list on different versions
        if 'gt_bboxes_labels' in results and isinstance(
                results['gt_bboxes_labels'], list):
            results['gt_bboxes_labels'] = np.array(
                results['gt_bboxes_labels'], dtype=np.int64)
        if 'gt_ignore_flags' in results and isinstance(
                results['gt_ignore_flags'], list):
            results['gt_ignore_flags'] = np.array(
                results['gt_ignore_flags'], dtype=bool)

        if 'bboxes' in results:
            if isinstance(results['bboxes'], list):
                results['bboxes'] = np.array(
                    results['bboxes'], dtype=np.float32)
            results['bboxes'] = results['bboxes'].reshape(-1, 4)
            results['bboxes'] = HorizontalBoxes(results['bboxes'])

            # filter label_fields
            if self.filter_lost_elements:

                for label in self.origin_label_fields:
                    results[label] = np.array(
                        [results[label][i] for i in results['idx_mapper']])
                if 'masks' in results:
                    assert ori_masks is not None
                    results['masks'] = np.array(
                        [results['masks'][i] for i in results['idx_mapper']])
                    results['masks'] = ori_masks.__class__(
                        results['masks'], ori_masks.height, ori_masks.width)

                if (not len(results['idx_mapper'])
                        and self.skip_img_without_anno):
                    return None
        elif 'masks' in results:
            results['masks'] = ori_masks.__class__(results['masks'],
                                                   ori_masks.height,
                                                   ori_masks.width)

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
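
# Sketch of a typical Albu configuration with bbox_params, assuming
# albumentations is installed. 'pascal_voc' matches the (x1, y1, x2, y2)
# layout of HorizontalBoxes, and label fields such as 'gt_bboxes_labels'
# are carried through filtration via label_fields; the exact field names
# below are illustrative.
#
#     albu_aug = Albu(
#         transforms=[dict(type='RandomBrightnessContrast', p=0.2)],
#         bbox_params=dict(
#             type='BboxParams',
#             format='pascal_voc',
#             label_fields=['gt_bboxes_labels', 'gt_ignore_flags'],
#             min_visibility=0.0,
#             filter_lost_elements=True),
#         keymap=dict(img='image', gt_bboxes='bboxes'),
#         skip_img_without_anno=True)
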
@TRANSFORMS.register_module()
@avoid_cache_randomness
class RandomCenterCropPad(BaseTransform):
    """Random center crop and random around padding for CornerNet.

    This operation generates a randomly cropped image from the original
    image and pads it simultaneously. Different from :class:`RandomCrop`,
    the output shape may not equal ``crop_size`` strictly. We choose a
    random value from ``ratios`` and the output shape could be larger or
    smaller than ``crop_size``. The padding operation is also different
    from :class:`Pad`: here we use around padding instead of right-bottom
    padding.

    The relation between output image (padding image) and original image:

    .. code:: text

                        output image

               +----------------------------+
               |          padded area       |
        +------|----------------------------|----------+
        |      |         cropped area       |          |
        |      |         +---------------+  |          |
        |      |         |    .   center |  |          | original image
        |      |         |        range  |  |          |
        |      |         +---------------+  |          |
        +------|----------------------------|----------+
               |          padded area       |
               +----------------------------+

    There are 5 main areas in the figure:

    - output image: output image of this operation, also called padding
      image in the following instructions.
    - original image: input image of this operation.
    - padded area: non-intersecting area of output image and original image.
    - cropped area: the overlap of output image and original image.
    - center range: a smaller area where the random center is chosen from.
      center range is computed by ``border`` and original image's shape
      to avoid the random center being too close to the original image's
      border.

    Also, this operation acts differently in train and test mode; the
    summary pipeline is listed below.

    Train pipeline:

    1. Choose a ``random_ratio`` from ``ratios``; the shape of the padding
       image will be ``random_ratio * crop_size``.
    2. Choose a ``random_center`` in center range.
    3. Generate padding image with center matching the ``random_center``.
    4. Initialize the padding image with pixel value equal to ``mean``.
    5. Copy the cropped area to the padding image.
    6. Refine annotations.

    Test pipeline:

    1. Compute output shape according to ``test_pad_mode``.
    2. Generate padding image with center matching the original image
       center.
    3. Initialize the padding image with pixel value equal to ``mean``.
    4. Copy the ``cropped area`` to the padding image.

    Required Keys:

    - img (np.float32)
    - img_shape (tuple)
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img (np.float32)
    - img_shape (tuple)
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Args:
        crop_size (tuple, optional): Expected size after crop; the final
            size will be computed according to the ratio. Requires
            (width, height) in train mode, and None in test mode.
        ratios (tuple, optional): Randomly select a ratio from the tuple
            and crop the image to (crop_size[0] * ratio) *
            (crop_size[1] * ratio). Only available in train mode.
            Defaults to (0.9, 1.0, 1.1).
        border (int, optional): Max distance from the center-select area to
            the image border. Only available in train mode.
            Defaults to 128.
        mean (sequence, optional): Mean values of 3 channels.
        std (sequence, optional): Std values of 3 channels.
        to_rgb (bool, optional): Whether to convert the image from BGR to
            RGB.
        test_mode (bool): Whether to involve random variables in the
            transform. In train mode, crop_size is fixed while center
            coords and ratio are randomly selected from predefined lists.
            In test mode, crop_size is the image's original shape while
            center coords and ratio are fixed. Defaults to False.
        test_pad_mode (tuple, optional): Padding method and padding shape
            value, only available in test mode. Default is using
            'logical_or' with 127 as padding shape value.

            - 'logical_or': final_shape = input_shape | padding_shape_value
            - 'size_divisor': final_shape = int(
              ceil(input_shape / padding_shape_value) * padding_shape_value)

            Defaults to ('logical_or', 127).
        test_pad_add_pix (int): Extra padding pixel in test mode.
            Defaults to 0.
        bbox_clip_border (bool): Whether to clip the objects outside
            the border of the image. Defaults to True.
    """

    def __init__(self,
                 crop_size: Optional[tuple] = None,
                 ratios: Optional[tuple] = (0.9, 1.0, 1.1),
                 border: Optional[int] = 128,
                 mean: Optional[Sequence] = None,
                 std: Optional[Sequence] = None,
                 to_rgb: Optional[bool] = None,
                 test_mode: bool = False,
                 test_pad_mode: Optional[tuple] = ('logical_or', 127),
                 test_pad_add_pix: int = 0,
                 bbox_clip_border: bool = True) -> None:
        if test_mode:
            assert crop_size is None, 'crop_size must be None in test mode'
            assert ratios is None, 'ratios must be None in test mode'
            assert border is None, 'border must be None in test mode'
            assert isinstance(test_pad_mode, (list, tuple))
            assert test_pad_mode[0] in ['logical_or', 'size_divisor']
        else:
            assert isinstance(crop_size, (list, tuple))
            assert crop_size[0] > 0 and crop_size[1] > 0, (
                'crop_size must be > 0 in train mode')
            assert isinstance(ratios, (list, tuple))
            assert test_pad_mode is None, (
                'test_pad_mode must be None in train mode')

        self.crop_size = crop_size
        self.ratios = ratios
        self.border = border
        # We do not set default value to mean, std and to_rgb because these
        # hyper-parameters are easy to forget but could affect the
        # performance. Please use the same setting as Normalize for
        # performance assurance.
        assert mean is not None and std is not None and to_rgb is not None
        self.to_rgb = to_rgb
        self.input_mean = mean
        self.input_std = std
        if to_rgb:
            self.mean = mean[::-1]
            self.std = std[::-1]
        else:
            self.mean = mean
            self.std = std
        self.test_mode = test_mode
        self.test_pad_mode = test_pad_mode
        self.test_pad_add_pix = test_pad_add_pix
        self.bbox_clip_border = bbox_clip_border

    def _get_border(self, border, size):
        """Get final border for the target size.

        This function generates a ``final_border`` according to the image's
        shape. The area between ``final_border`` and ``size - final_border``
        is the ``center range``. We randomly choose the center from the
        ``center range`` to avoid the random center being too close to the
        original image's border. Also ``center range`` should be larger
        than 0.

        Args:
            border (int): The initial border, default is 128.
            size (int): The width or height of original image.

        Returns:
            int: The final border.
        """
        k = 2 * border / size
        i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
        return border // i

    def _filter_boxes(self, patch, boxes):
        """Check whether the center of each box is in the patch.

        Args:
            patch (list[int]): The cropped area, [left, top, right, bottom].
            boxes (numpy array, (N x 4)): Ground truth boxes.

        Returns:
            mask (numpy array, (N,)): Each box is inside or outside the
            patch.
        """
        center = boxes.centers.numpy()
        mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
            center[:, 0] < patch[2]) * (
                center[:, 1] < patch[3])
        return mask

    def _crop_image_and_paste(self, image, center, size):
        """Crop image with a given center and size, then paste the cropped
        image to a blank image with the two centers aligned.

        This function is equivalent to generating a blank image with
        ``size`` as its shape, then covering it on the original image with
        the two centers (the center of the blank image and the random
        center of the original image) aligned. The overlap area is pasted
        from the original image and the outside area is filled with
        ``mean pixel``.

        Args:
            image (np array, H x W x C): Original image.
            center (list[int]): Target crop center coord.
            size (list[int]): Target crop size. [target_h, target_w]

        Returns:
            cropped_img (np array, target_h x target_w x C): Cropped image.
            border (np array, 4): The distance of four borders of
                ``cropped_img`` to the original image area, [top, bottom,
                left, right]
            patch (list[int]): The cropped area, [left, top, right, bottom].
        """
        center_y, center_x = center
        target_h, target_w = size
        img_h, img_w, img_c = image.shape

        x0 = max(0, center_x - target_w // 2)
        x1 = min(center_x + target_w // 2, img_w)
        y0 = max(0, center_y - target_h // 2)
        y1 = min(center_y + target_h // 2, img_h)
        patch = np.array((int(x0), int(y0), int(x1), int(y1)))

        left, right = center_x - x0, x1 - center_x
        top, bottom = center_y - y0, y1 - center_y

        cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
        cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
        for i in range(img_c):
            cropped_img[:, :, i] += self.mean[i]
        y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
        x_slice = slice(cropped_center_x - left, cropped_center_x + right)
        cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]

        border = np.array([
            cropped_center_y - top, cropped_center_y + bottom,
            cropped_center_x - left, cropped_center_x + right
        ],
                          dtype=np.float32)

        return cropped_img, border, patch

    def _train_aug(self, results):
        """Random crop and around padding the original image.

        Args:
            results (dict): Image information in the augment pipeline.

        Returns:
            results (dict): The updated dict.
        """
        img = results['img']
        h, w, c = img.shape
        gt_bboxes = results['gt_bboxes']
        while True:
            scale = random.choice(self.ratios)
            new_h = int(self.crop_size[1] * scale)
            new_w = int(self.crop_size[0] * scale)
            h_border = self._get_border(self.border, h)
            w_border = self._get_border(self.border, w)

            for i in range(50):
                center_x = random.randint(low=w_border, high=w - w_border)
                center_y = random.randint(low=h_border, high=h - h_border)

                cropped_img, border, patch = self._crop_image_and_paste(
                    img, [center_y, center_x], [new_h, new_w])

                # if the image has no valid bbox, any crop patch is valid
                if len(gt_bboxes) == 0:
                    results['img'] = cropped_img
                    results['img_shape'] = cropped_img.shape[:2]
                    return results

                mask = self._filter_boxes(patch, gt_bboxes)
                if not mask.any():
                    continue

                results['img'] = cropped_img
                results['img_shape'] = cropped_img.shape[:2]

                x0, y0, x1, y1 = patch

                left_w, top_h = center_x - x0, center_y - y0
                cropped_center_x, cropped_center_y = new_w // 2, new_h // 2

                # crop bboxes accordingly and clip to the image boundary
                gt_bboxes = gt_bboxes[mask]
                gt_bboxes.translate_([
                    cropped_center_x - left_w - x0,
                    cropped_center_y - top_h - y0
                ])
                if self.bbox_clip_border:
                    gt_bboxes.clip_([new_h, new_w])
                keep = gt_bboxes.is_inside([new_h, new_w]).numpy()
                gt_bboxes = gt_bboxes[keep]

                results['gt_bboxes'] = gt_bboxes

                # ignore_flags
                if results.get('gt_ignore_flags', None) is not None:
                    gt_ignore_flags = results['gt_ignore_flags'][mask]
                    results['gt_ignore_flags'] = \
                        gt_ignore_flags[keep]

                # labels
                if results.get('gt_bboxes_labels', None) is not None:
                    gt_labels = results['gt_bboxes_labels'][mask]
                    results['gt_bboxes_labels'] = gt_labels[keep]

                if 'gt_masks' in results or 'gt_seg_map' in results:
                    raise NotImplementedError(
                        'RandomCenterCropPad only supports bbox.')

                return results

    def _test_aug(self, results):
        """Around padding the original image without cropping.

        The padding mode and value are from ``test_pad_mode``.

        Args:
            results (dict): Image information in the augment pipeline.

        Returns:
            results (dict): The updated dict.
        """
        img = results['img']
        h, w, c = img.shape
        if self.test_pad_mode[0] in ['logical_or']:
            # self.test_pad_add_pix is only used for centernet
            target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix
            target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix
        elif self.test_pad_mode[0] in ['size_divisor']:
            divisor = self.test_pad_mode[1]
            target_h = int(np.ceil(h / divisor)) * divisor
            target_w = int(np.ceil(w / divisor)) * divisor
        else:
            raise NotImplementedError(
                'RandomCenterCropPad only supports two testing pad modes: '
                'logical_or and size_divisor.')

        cropped_img, border, _ = self._crop_image_and_paste(
            img, [h // 2, w // 2], [target_h, target_w])
        results['img'] = cropped_img
        results['img_shape'] = cropped_img.shape[:2]
        results['border'] = border
        return results

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        img = results['img']
        assert img.dtype == np.float32, (
            'RandomCenterCropPad needs the input image of dtype np.float32,'
            ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
        h, w, c = img.shape
        assert c == len(self.mean)
        if self.test_mode:
            return self._test_aug(results)
        else:
            return self._train_aug(results)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(crop_size={self.crop_size}, '
        repr_str += f'ratios={self.ratios}, '
        repr_str += f'border={self.border}, '
        repr_str += f'mean={self.input_mean}, '
        repr_str += f'std={self.input_std}, '
        repr_str += f'to_rgb={self.to_rgb}, '
        repr_str += f'test_mode={self.test_mode}, '
        repr_str += f'test_pad_mode={self.test_pad_mode}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
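
# Train/test usage sketch. mean/std/to_rgb must match the Normalize settings
# used elsewhere in the pipeline; the values below are common ImageNet BGR
# statistics and the (511, 511) crop size is illustrative, not required.
#
#     train_crop = RandomCenterCropPad(
#         crop_size=(511, 511),
#         ratios=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3),
#         mean=[123.675, 116.28, 103.53],
#         std=[58.395, 57.12, 57.375],
#         to_rgb=True,
#         test_mode=False,
#         test_pad_mode=None)
#     test_crop = RandomCenterCropPad(
#         crop_size=None, ratios=None, border=None,
#         mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],
#         to_rgb=True, test_mode=True, test_pad_mode=('logical_or', 127))
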
@TRANSFORMS.register_module()
class CutOut(BaseTransform):
    """CutOut operation.

    Randomly drop some regions of the image, as used in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        n_holes (int or tuple[int, int]): Number of regions to be dropped.
            If it is given as a tuple, the number of holes will be randomly
            selected from the closed interval [``n_holes[0]``,
            ``n_holes[1]``].
        cutout_shape (tuple[int, int] or list[tuple[int, int]], optional):
            The candidate shape of dropped regions. It can be
            ``tuple[int, int]`` to use a fixed cutout shape, or
            ``list[tuple[int, int]]`` to randomly choose shape
            from the list. Defaults to None.
        cutout_ratio (tuple[float, float] or list[tuple[float, float]],
            optional): The candidate ratio of dropped regions. It can be
            ``tuple[float, float]`` to use a fixed ratio or
            ``list[tuple[float, float]]`` to randomly choose ratio
            from the list. Please note that ``cutout_shape`` and
            ``cutout_ratio`` cannot be both given at the same time.
            Defaults to None.
        fill_in (tuple[float, float, float] or tuple[int, int, int]): The
            value of pixel to fill in the dropped regions.
            Defaults to (0, 0, 0).
    """

    def __init__(
        self,
        n_holes: Union[int, Tuple[int, int]],
        cutout_shape: Optional[Union[Tuple[int, int],
                                     List[Tuple[int, int]]]] = None,
        cutout_ratio: Optional[Union[Tuple[float, float],
                                     List[Tuple[float, float]]]] = None,
        fill_in: Union[Tuple[float, float, float], Tuple[int, int,
                                                         int]] = (0, 0, 0)
    ) -> None:

        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'
        assert (isinstance(cutout_shape, (list, tuple))
                or isinstance(cutout_ratio, (list, tuple)))
        if isinstance(n_holes, tuple):
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
        else:
            n_holes = (n_holes, n_holes)
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.with_ratio = cutout_ratio is not None
        self.candidates = cutout_ratio if self.with_ratio else cutout_shape
        if not isinstance(self.candidates, list):
            self.candidates = [self.candidates]

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Call function to drop some regions of image."""
        h, w, c = results['img'].shape
        n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        for _ in range(n_holes):
            x1 = np.random.randint(0, w)
            y1 = np.random.randint(0, h)
            index = np.random.randint(0, len(self.candidates))
            if not self.with_ratio:
                cutout_w, cutout_h = self.candidates[index]
            else:
                cutout_w = int(self.candidates[index][0] * w)
                cutout_h = int(self.candidates[index][1] * h)

            x2 = np.clip(x1 + cutout_w, 0, w)
            y2 = np.clip(y1 + cutout_h, 0, h)
            results['img'][y1:y2, x1:x2, :] = self.fill_in

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(n_holes={self.n_holes}, '
        repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
                     else f'cutout_shape={self.candidates}, ')
        repr_str += f'fill_in={self.fill_in})'
        return repr_str
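
# Two equivalent ways to parameterize CutOut (sketch): fixed pixel shapes,
# or hole sizes given as ratios of the image size. Exactly one of
# cutout_shape / cutout_ratio may be set; the values below are illustrative.
#
#     cutout_px = CutOut(n_holes=(1, 3), cutout_shape=[(8, 8), (16, 16)])
#     cutout_ratio = CutOut(n_holes=2, cutout_ratio=(0.1, 0.2))
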
@TRANSFORMS.register_module()
class Mosaic(BaseTransform):
    """Mosaic augmentation.

    Given 4 images, the mosaic transform combines them into one output
    image. The output image is composed of parts from each sub-image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

    1. Choose the mosaic center as the intersection of the 4 images.
    2. Get the top-left image according to the index, and randomly
       sample 3 other images from the custom dataset.
    3. A sub-image is cropped if it is larger than the mosaic patch.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        img_scale (Sequence[int]): Image size after mosaic pipeline of
            single image. The shape order should be (width, height).
            Defaults to (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Defaults to (0.5, 1.5).
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. In some datasets like MOT17,
            the gt bboxes are allowed to cross the border of images.
            Therefore, we don't need to clip the gt bboxes in these cases.
            Defaults to True.
        pad_val (int): Pad value. Defaults to 114.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
    """

    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),
                 center_ratio_range: Tuple[float, float] = (0.5, 1.5),
                 bbox_clip_border: bool = True,
                 pad_val: float = 114.0,
                 prob: float = 1.0) -> None:
        assert isinstance(img_scale, tuple)
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
                                 f'got {prob}.'

        log_img_scale(img_scale, skip_square=True, shape_order='wh')
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.bbox_clip_border = bbox_clip_border
        self.pad_val = pad_val
        self.prob = prob

    @cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> List[int]:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: indexes.
        """
        indexes = [random.randint(0, len(dataset)) for _ in range(3)]
        return indexes

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Mosaic transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        if random.uniform(0, 1) > self.prob:
            return results

        assert 'mix_results' in results
        mosaic_bboxes = []
        mosaic_bboxes_labels = []
        mosaic_ignore_flags = []
        if len(results['img'].shape) == 3:
            mosaic_img = np.full(
                (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
                self.pad_val,
                dtype=results['img'].dtype)
        else:
            mosaic_img = np.full(
                (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
                self.pad_val,
                dtype=results['img'].dtype)

        # mosaic center x, y
        center_x = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[0])
        center_y = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[1])
        center_position = (center_x, center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        for i, loc in enumerate(loc_strs):
            if loc == 'top_left':
                results_patch = copy.deepcopy(results)
            else:
                results_patch = copy.deepcopy(results['mix_results'][i - 1])

            img_i = results_patch['img']
            h_i, w_i = img_i.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(self.img_scale[1] / h_i,
                                self.img_scale[0] / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

            # adjust coordinate
            gt_bboxes_i = results_patch['gt_bboxes']
            gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
            gt_ignore_flags_i = results_patch['gt_ignore_flags']

            padw = x1_p - x1_c
            padh = y1_p - y1_c
            gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
            gt_bboxes_i.translate_([padw, padh])
            mosaic_bboxes.append(gt_bboxes_i)
            mosaic_bboxes_labels.append(gt_bboxes_labels_i)
            mosaic_ignore_flags.append(gt_ignore_flags_i)

        mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
        mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
        mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)

        if self.bbox_clip_border:
            mosaic_bboxes.clip_(
                [2 * self.img_scale[1], 2 * self.img_scale[0]])
        # remove outside bboxes
        inside_inds = mosaic_bboxes.is_inside(
            [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
        mosaic_bboxes = mosaic_bboxes[inside_inds]
        mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
        mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape[:2]
        results['gt_bboxes'] = mosaic_bboxes
        results['gt_bboxes_labels'] = mosaic_bboxes_labels
        results['gt_ignore_flags'] = mosaic_ignore_flags
        return results

    def _mosaic_combine(
            self, loc: str, center_position_xy: Sequence[float],
            img_shape_wh: Sequence[int]) -> Tuple[Tuple[int], Tuple[int]]:
        """Calculate global coordinate of mosaic image and local coordinate
        of cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4
                images, (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
            cropping

            - paste_coord (tuple): paste corner coordinate in mosaic image.
            - crop_coord (tuple): crop corner coordinate in mosaic image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             center_position_xy[0], \
                             center_position_xy[1]
            crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
                y2 - y1), img_shape_wh[0], img_shape_wh[1]

        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[0] * 2), \
                             center_position_xy[1]
            crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
                img_shape_wh[0], x2 - x1), img_shape_wh[1]

        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             center_position_xy[1], \
                             center_position_xy[0], \
                             min(self.img_scale[1] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
                y2 - y1, img_shape_wh[1])

        else:
            # index3 to bottom right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             center_position_xy[1], \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[0] * 2), \
                             min(self.img_scale[1] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = 0, 0, min(img_shape_wh[0],
                                   x2 - x1), min(y2 - y1, img_shape_wh[1])

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'prob={self.prob})'
        return repr_str
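
# Mosaic needs results['mix_results'] to be filled in beforehand, which is
# normally done by wrapping the dataset in MultiImageMixDataset so that
# get_indexes() can sample the 3 extra images. A YOLOX-style config sketch
# (the RandomAffine settings are illustrative assumptions):
#
#     train_pipeline = [
#         dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
#         dict(type='RandomAffine',
#              scaling_ratio_range=(0.1, 2),
#              border=(-320, -320)),  # -img_scale // 2 crops back to size
#     ]
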
@TRANSFORMS.register_module()
class MixUp(BaseTransform):
    """MixUp data augmentation.

    .. code:: text

                    mixup transform
           +------------------------------+
           | mixup image   |              |
           |      +--------|--------+     |
           |      |        |        |     |
           |---------------+        |     |
           |      |                 |     |
           |      |      image      |     |
           |      |                 |     |
           |      |                 |     |
           |      |-----------------+     |
           |             pad              |
           +------------------------------+

    The mixup transform steps are as follows:

    1. Another random image is picked by the dataset and embedded in
       the top left patch (after padding and resizing).
    2. The target of mixup transform is the weighted average of mixup
       image and origin image.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        img_scale (Sequence[int]): Image output size after mixup pipeline.
            The shape order should be (width, height).
            Defaults to (640, 640).
        ratio_range (Sequence[float]): Scale ratio of mixup image.
            Defaults to (0.5, 1.5).
        flip_ratio (float): Horizontal flip ratio of mixup image.
            Defaults to 0.5.
        pad_val (int): Pad value. Defaults to 114.
        max_iters (int): The maximum number of iterations. If the number of
            iterations is greater than ``max_iters``, but gt_bbox is still
            empty, then the iteration is terminated. Defaults to 15.
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. In some datasets like MOT17,
            the gt bboxes are allowed to cross the border of images.
            Therefore, we don't need to clip the gt bboxes in these cases.
            Defaults to True.
    """

    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),
                 ratio_range: Tuple[float, float] = (0.5, 1.5),
                 flip_ratio: float = 0.5,
                 pad_val: float = 114.0,
                 max_iters: int = 15,
                 bbox_clip_border: bool = True) -> None:
        assert isinstance(img_scale, tuple)
        log_img_scale(img_scale, skip_square=True, shape_order='wh')
        self.dynamic_scale = img_scale
        self.ratio_range = ratio_range
        self.flip_ratio = flip_ratio
        self.pad_val = pad_val
        self.max_iters = max_iters
        self.bbox_clip_border = bbox_clip_border

    @cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            int: index.
        """
        for i in range(self.max_iters):
            index = random.randint(0, len(dataset))
            gt_bboxes_i = dataset[index]['gt_bboxes']
            if len(gt_bboxes_i) != 0:
                break
        return index

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """MixUp transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        assert len(
            results['mix_results']) == 1, 'MixUp only supports 2 images now!'

        if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
            # empty bbox
            return results

        retrieve_results = results['mix_results'][0]
        retrieve_img = retrieve_results['img']

        jit_factor = random.uniform(*self.ratio_range)
        is_flip = random.uniform(0, 1) > self.flip_ratio

        if len(retrieve_img.shape) == 3:
            out_img = np.ones(
                (self.dynamic_scale[1], self.dynamic_scale[0], 3),
                dtype=retrieve_img.dtype) * self.pad_val
        else:
            out_img = np.ones(
                self.dynamic_scale[::-1],
                dtype=retrieve_img.dtype) * self.pad_val

        # 1. keep_ratio resize
        scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
                          self.dynamic_scale[0] / retrieve_img.shape[1])
        retrieve_img = mmcv.imresize(
            retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
                           int(retrieve_img.shape[0] * scale_ratio)))

        # 2. paste
        out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img

        # 3. scale jit
        scale_ratio *= jit_factor
        out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
                                          int(out_img.shape[0] * jit_factor)))

        # 4. flip
        if is_flip:
            out_img = out_img[:, ::-1, :]

        # 5. random crop
        ori_img = results['img']
        origin_h, origin_w = out_img.shape[:2]
        target_h, target_w = ori_img.shape[:2]
        padded_img = np.ones((max(origin_h, target_h), max(
            origin_w, target_w), 3)) * self.pad_val
        padded_img = padded_img.astype(np.uint8)
        padded_img[:origin_h, :origin_w] = out_img

        x_offset, y_offset = 0, 0
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w)
        padded_cropped_img = padded_img[y_offset:y_offset + target_h,
                                        x_offset:x_offset + target_w]

        # 6. adjust bbox
        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
        if self.bbox_clip_border:
            retrieve_gt_bboxes.clip_([origin_h, origin_w])

        if is_flip:
            retrieve_gt_bboxes.flip_([origin_h, origin_w],
                                     direction='horizontal')

        # 7. filter
        cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
        cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
        if self.bbox_clip_border:
            cp_retrieve_gt_bboxes.clip_([target_h, target_w])

        # 8. mix up
        ori_img = ori_img.astype(np.float32)
        mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(
            np.float32)

        retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
        retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']

        mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
            (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
        mixup_gt_bboxes_labels = np.concatenate(
            (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
        mixup_gt_ignore_flags = np.concatenate(
            (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)

        # remove outside bbox
        inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
        mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
        mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
        mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]

        results['img'] = mixup_img.astype(np.uint8)
        results['img_shape'] = mixup_img.shape[:2]
        results['gt_bboxes'] = mixup_gt_bboxes
        results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
        results['gt_ignore_flags'] = mixup_gt_ignore_flags
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(dynamic_scale={self.dynamic_scale}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'flip_ratio={self.flip_ratio}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'max_iters={self.max_iters}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
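
# Like Mosaic, MixUp relies on MultiImageMixDataset to populate
# results['mix_results'] with exactly one extra sample. A config-level
# sketch of its usual placement after Mosaic + RandomAffine (the parameter
# values are illustrative):
#
#     dict(type='MixUp',
#          img_scale=(640, 640),
#          ratio_range=(0.8, 1.6),
#          pad_val=114.0)
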
@TRANSFORMS.register_module()
class RandomAffine(BaseTransform):
    """Random affine transform data augmentation.

    This operation randomly generates an affine transform matrix which
    includes rotation, translation, shear and scaling transforms.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        max_rotate_degree (float): Maximum degrees of rotation transform.
            Defaults to 10.
        max_translate_ratio (float): Maximum ratio of translation.
            Defaults to 0.1.
        scaling_ratio_range (tuple[float]): Min and max ratio of
            scaling transform. Defaults to (0.5, 1.5).
        max_shear_degree (float): Maximum degrees of shear
            transform. Defaults to 2.
        border (tuple[int]): Distance from width and height sides of input
            image to adjust output shape. Only used in mosaic dataset.
            Defaults to (0, 0).
        border_val (tuple[int]): Border padding values of 3 channels.
            Defaults to (114, 114, 114).
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. In some datasets like MOT17,
            the gt bboxes are allowed to cross the border of images.
            Therefore, we don't need to clip the gt bboxes in these cases.
            Defaults to True.
    """

    def __init__(self,
                 max_rotate_degree: float = 10.0,
                 max_translate_ratio: float = 0.1,
                 scaling_ratio_range: Tuple[float, float] = (0.5, 1.5),
                 max_shear_degree: float = 2.0,
                 border: Tuple[int, int] = (0, 0),
                 border_val: Tuple[int, int, int] = (114, 114, 114),
                 bbox_clip_border: bool = True) -> None:
        assert 0 <= max_translate_ratio <= 1
        assert scaling_ratio_range[0] <= scaling_ratio_range[1]
        assert scaling_ratio_range[0] > 0
        self.max_rotate_degree = max_rotate_degree
        self.max_translate_ratio = max_translate_ratio
        self.scaling_ratio_range = scaling_ratio_range
        self.max_shear_degree = max_shear_degree
        self.border = border
        self.border_val = border_val
        self.bbox_clip_border = bbox_clip_border

    @cache_randomness
    def _get_random_homography_matrix(self, height, width):
        # Rotation
        rotation_degree = random.uniform(-self.max_rotate_degree,
                                         self.max_rotate_degree)
        rotation_matrix = self._get_rotation_matrix(rotation_degree)

        # Scaling
        scaling_ratio = random.uniform(self.scaling_ratio_range[0],
                                       self.scaling_ratio_range[1])
        scaling_matrix = self._get_scaling_matrix(scaling_ratio)

        # Shear
        x_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        y_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        shear_matrix = self._get_shear_matrix(x_degree, y_degree)

        # Translation
        trans_x = random.uniform(-self.max_translate_ratio,
                                 self.max_translate_ratio) * width
        trans_y = random.uniform(-self.max_translate_ratio,
                                 self.max_translate_ratio) * height
        translate_matrix = self._get_translation_matrix(trans_x, trans_y)

        warp_matrix = (
            translate_matrix @ shear_matrix @ rotation_matrix
            @ scaling_matrix)
        return warp_matrix

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        img = results['img']
        height = img.shape[0] + self.border[1] * 2
        width = img.shape[1] + self.border[0] * 2

        warp_matrix = self._get_random_homography_matrix(height, width)

        img = cv2.warpPerspective(
            img,
            warp_matrix,
            dsize=(width, height),
            borderValue=self.border_val)
        results['img'] = img
        results['img_shape'] = img.shape[:2]

        bboxes = results['gt_bboxes']
        num_bboxes = len(bboxes)
        if num_bboxes:
            bboxes.project_(warp_matrix)
            if self.bbox_clip_border:
                bboxes.clip_([height, width])
            # remove outside bbox
            valid_index = bboxes.is_inside([height, width]).numpy()
            results['gt_bboxes'] = bboxes[valid_index]
            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                valid_index]
            results['gt_ignore_flags'] = results['gt_ignore_flags'][
                valid_index]

            if 'gt_masks' in results:
                raise NotImplementedError('RandomAffine only supports bbox.')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
        repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
        repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
        repr_str += f'max_shear_degree={self.max_shear_degree}, '
        repr_str += f'border={self.border}, '
        repr_str += f'border_val={self.border_val}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str

    @staticmethod
    def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
        radian = math.radians(rotate_degrees)
        rotation_matrix = np.array(
            [[np.cos(radian), -np.sin(radian), 0.],
             [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
            dtype=np.float32)
        return rotation_matrix

    @staticmethod
    def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
        scaling_matrix = np.array(
            [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
            dtype=np.float32)
        return scaling_matrix

    @staticmethod
    def _get_shear_matrix(x_shear_degrees: float,
                          y_shear_degrees: float) -> np.ndarray:
        x_radian = math.radians(x_shear_degrees)
        y_radian = math.radians(y_shear_degrees)
        shear_matrix = np.array([[1, np.tan(x_radian), 0.],
                                 [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
                                dtype=np.float32)
        return shear_matrix

    @staticmethod
    def _get_translation_matrix(x: float, y: float) -> np.ndarray:
        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
                                      dtype=np.float32)
        return translation_matrix
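
# The warp matrix above is the composition T @ Sh @ R @ Sc applied to
# homogeneous pixel coordinates. A standalone sketch of the same
# composition for a single point, using the static helpers defined above
# (the numeric values are arbitrary assumptions):
#
#     import numpy as np
#     R = RandomAffine._get_rotation_matrix(10.0)
#     Sc = RandomAffine._get_scaling_matrix(1.2)
#     Sh = RandomAffine._get_shear_matrix(2.0, 0.0)
#     T = RandomAffine._get_translation_matrix(5.0, -3.0)
#     warp = T @ Sh @ R @ Sc
#     pt = np.array([100.0, 50.0, 1.0])  # (x, y, 1)
#     x_new, y_new, _ = warp @ pt
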
@TRANSFORMS.register_module()
class YOLOXHSVRandomAug(BaseTransform):
    """Apply HSV augmentation to image sequentially. It is referenced from
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/data_augment.py#L21.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        hue_delta (int): delta of hue. Defaults to 5.
        saturation_delta (int): delta of saturation. Defaults to 30.
        value_delta (int): delta of value. Defaults to 30.
    """

    def __init__(self,
                 hue_delta: int = 5,
                 saturation_delta: int = 30,
                 value_delta: int = 30) -> None:
        self.hue_delta = hue_delta
        self.saturation_delta = saturation_delta
        self.value_delta = value_delta

    @cache_randomness
    def _get_hsv_gains(self):
        hsv_gains = np.random.uniform(-1, 1, 3) * [
            self.hue_delta, self.saturation_delta, self.value_delta
        ]
        # random selection of h, s, v
        hsv_gains *= np.random.randint(0, 2, 3)
        # prevent overflow
        hsv_gains = hsv_gains.astype(np.int16)
        return hsv_gains

    def transform(self, results: dict) -> dict:
        img = results['img']
        hsv_gains = self._get_hsv_gains()
        img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)

        img_hsv[..., 0] = (img_hsv[..., 0] + hsv_gains[0]) % 180
        img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_gains[1], 0, 255)
        img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_gains[2], 0, 255)
        cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(hue_delta={self.hue_delta}, '
        repr_str += f'saturation_delta={self.saturation_delta}, '
        repr_str += f'value_delta={self.value_delta})'
        return repr_str
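
# Direct-use sketch: note that transform() writes back into the input array
# via the dst= argument of cv2.cvtColor, so the original image is modified
# in place; `img` below is a hypothetical np.uint8 BGR array.
#
#     hsv_aug = YOLOXHSVRandomAug(hue_delta=5, saturation_delta=30,
#                                 value_delta=30)
#     results = hsv_aug({'img': img})
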
@TRANSFORMS.register_module()
class CopyPaste(BaseTransform):
    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
    Segmentation. The simple copy-paste transform steps are as follows:

    1. The destination image is already resized with aspect ratio kept,
       cropped and padded.
    2. Randomly select a source image, which is also already resized
       with aspect ratio kept, cropped and padded in a similar way
       as the destination image.
    3. Randomly select some objects from the source image.
    4. Paste these source objects to the destination image directly,
       because the source and destination images have the same size.
    5. Update object masks of the destination image, since some original
       objects may be occluded.
    6. Generate bboxes from the updated destination masks and
       filter some objects which are totally occluded, and adjust bboxes
       which are partly occluded.
    7. Append selected source bboxes, masks, and labels.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_masks (BitmapMasks) (optional)

    Modified Keys:

    - img
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    - gt_masks (optional)

    Args:
        max_num_pasted (int): The maximum number of pasted objects.
            Defaults to 100.
        bbox_occluded_thr (int): The threshold of occluded bbox.
            Defaults to 10.
        mask_occluded_thr (int): The threshold of occluded mask.
            Defaults to 300.
        selected (bool): Whether to select objects or not. If selected is
            False, all objects of the source image will be pasted to the
            destination image. Defaults to True.
    """

    def __init__(
        self,
        max_num_pasted: int = 100,
        bbox_occluded_thr: int = 10,
        mask_occluded_thr: int = 300,
        selected: bool = True,
    ) -> None:
        self.max_num_pasted = max_num_pasted
        self.bbox_occluded_thr = bbox_occluded_thr
        self.mask_occluded_thr = mask_occluded_thr
        self.selected = selected

    @cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            int: Index of the selected source image.
        """
        # `random.randint` is inclusive on both ends, so the upper bound
        # must be `len(dataset) - 1` to yield a valid index.
        return random.randint(0, len(dataset) - 1)

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Transform function to make a copy-paste of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with copy-paste transformed.
        """
        assert 'mix_results' in results
        num_images = len(results['mix_results'])
        assert num_images == 1, \
            f'CopyPaste only supports processing 2 images, got {num_images}'
        if self.selected:
            selected_results = self._select_object(results['mix_results'][0])
        else:
            selected_results = results['mix_results'][0]
        return self._copy_paste(results, selected_results)

    @cache_randomness
    def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
        max_num_pasted = min(num_bboxes + 1, self.max_num_pasted)
        num_pasted = np.random.randint(0, max_num_pasted)
        return np.random.choice(num_bboxes, size=num_pasted, replace=False)

    def _select_object(self, results: dict) -> dict:
        """Select some objects from the source results."""
        bboxes = results['gt_bboxes']
        labels = results['gt_bboxes_labels']
        masks = results['gt_masks']
        ignore_flags = results['gt_ignore_flags']

        selected_inds = self._get_selected_inds(bboxes.shape[0])

        selected_bboxes = bboxes[selected_inds]
        selected_labels = labels[selected_inds]
        selected_masks = masks[selected_inds]
        selected_ignore_flags = ignore_flags[selected_inds]

        results['gt_bboxes'] = selected_bboxes
        results['gt_bboxes_labels'] = selected_labels
        results['gt_masks'] = selected_masks
        results['gt_ignore_flags'] = selected_ignore_flags
        return results

    def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
        """CopyPaste transform function.

        Args:
            dst_results (dict): Result dict of the destination image.
            src_results (dict): Result dict of the source image.

        Returns:
            dict: Updated result dict.
        """
        dst_img = dst_results['img']
        dst_bboxes = dst_results['gt_bboxes']
        dst_labels = dst_results['gt_bboxes_labels']
        dst_masks = dst_results['gt_masks']
        dst_ignore_flags = dst_results['gt_ignore_flags']

        src_img = src_results['img']
        src_bboxes = src_results['gt_bboxes']
        src_labels = src_results['gt_bboxes_labels']
        src_masks = src_results['gt_masks']
        src_ignore_flags = src_results['gt_ignore_flags']

        if len(src_bboxes) == 0:
            return dst_results

        # update masks and generate bboxes from updated masks
        composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
        updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)
        updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes))
        assert len(updated_dst_bboxes) == len(updated_dst_masks)

        # filter totally occluded objects: keep a destination object if its
        # bbox moved by at most `bbox_occluded_thr` pixels on every coordinate
        # or its remaining mask area exceeds `mask_occluded_thr` pixels
        l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
        bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
            dim=-1).numpy()
        masks_inds = updated_dst_masks.masks.sum(
            axis=(1, 2)) > self.mask_occluded_thr
        valid_inds = bboxes_inds | masks_inds

        # Paste source objects to destination image directly
        img = dst_img * (1 - composed_mask[..., np.newaxis]
                         ) + src_img * composed_mask[..., np.newaxis]
        bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
        labels = np.concatenate([dst_labels[valid_inds], src_labels])
        masks = np.concatenate(
            [updated_dst_masks.masks[valid_inds], src_masks.masks])
        ignore_flags = np.concatenate(
            [dst_ignore_flags[valid_inds], src_ignore_flags])

        dst_results['img'] = img
        dst_results['gt_bboxes'] = bboxes
        dst_results['gt_bboxes_labels'] = labels
        dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
                                              masks.shape[2])
        dst_results['gt_ignore_flags'] = ignore_flags

        return dst_results

    def _get_updated_masks(self, masks: BitmapMasks,
                           composed_mask: np.ndarray) -> BitmapMasks:
        """Update masks with composed mask."""
        assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
            'Cannot compare two arrays of different size'
        masks.masks = np.where(composed_mask, 0, masks.masks)
        return masks

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(max_num_pasted={self.max_num_pasted}, '
        repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
        repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
        repr_str += f'selected={self.selected})'
        return repr_str
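

# A minimal usage sketch (illustrative only; `_demo_copy_paste` and
# `_fake_results` are hypothetical helpers, not part of the upstream API).
# The destination dict carries the pre-processed source image under
# 'mix_results'; with selected=False every source object is pasted.
def _demo_copy_paste() -> None:
    import numpy as np
    from mmdet.structures.bbox import HorizontalBoxes
    from mmdet.structures.mask import BitmapMasks

    def _fake_results(num_objects: int, h: int = 64, w: int = 64) -> dict:
        masks = np.zeros((num_objects, h, w), dtype=np.uint8)
        boxes = []
        for i in range(num_objects):
            x1, y1 = 10 * i, 10 * i
            masks[i, y1:y1 + 8, x1:x1 + 8] = 1
            boxes.append([x1, y1, x1 + 8, y1 + 8])
        return dict(
            img=np.zeros((h, w, 3), dtype=np.uint8),
            gt_bboxes=HorizontalBoxes(np.array(boxes, dtype=np.float32)),
            gt_bboxes_labels=np.zeros(num_objects, dtype=np.int64),
            gt_ignore_flags=np.zeros(num_objects, dtype=bool),
            gt_masks=BitmapMasks(masks, h, w))

    dst = _fake_results(2)
    dst['mix_results'] = [_fake_results(3)]
    out = CopyPaste(selected=False).transform(dst)
    # pasted source objects are appended after surviving destination objects
    assert len(out['gt_bboxes']) >= 3
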
@TRANSFORMS.register_module()
class RandomErasing(BaseTransform):
    """RandomErasing operation.

    Random Erasing randomly selects a rectangle region
    in an image and erases its pixels with random values.
    `RandomErasing <https://arxiv.org/abs/1708.04896>`_.

    Required Keys:

    - img
    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_masks (BitmapMasks) (optional)

    Modified Keys:

    - img
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    - gt_masks (optional)

    Args:
        n_patches (int or tuple[int, int]): Number of regions to be dropped.
            If it is given as a tuple, the number of patches will be randomly
            selected from the closed interval [``n_patches[0]``,
            ``n_patches[1]``].
        ratio (float or tuple[float, float]): The ratio of erased regions.
            It can be ``float`` to use a fixed ratio or ``tuple[float, float]``
            to randomly choose a ratio from the interval.
        squared (bool): Whether to erase square regions. Defaults to True.
        bbox_erased_thr (float): The threshold for the maximum area proportion
            of the bbox to be erased. When the proportion of the erased area
            of a bbox is greater than this threshold, the bbox will be
            removed. Defaults to 0.9.
        img_border_value (int or float or tuple): The fill value for erased
            image regions. If float, the same fill value is used for all
            three channels of the image. If tuple, it should have 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for the segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
    """

    def __init__(
        self,
        n_patches: Union[int, Tuple[int, int]],
        ratio: Union[float, Tuple[float, float]],
        squared: bool = True,
        bbox_erased_thr: float = 0.9,
        img_border_value: Union[int, float, tuple] = 128,
        mask_border_value: int = 0,
        seg_ignore_label: int = 255,
    ) -> None:
        if isinstance(n_patches, tuple):
            assert len(n_patches) == 2 and 0 <= n_patches[0] < n_patches[1]
        else:
            n_patches = (n_patches, n_patches)
        if isinstance(ratio, tuple):
            assert len(ratio) == 2 and 0 <= ratio[0] < ratio[1] <= 1
        else:
            ratio = (ratio, ratio)

        self.n_patches = n_patches
        self.ratio = ratio
        self.squared = squared
        self.bbox_erased_thr = bbox_erased_thr
        self.img_border_value = img_border_value
        self.mask_border_value = mask_border_value
        self.seg_ignore_label = seg_ignore_label

    @cache_randomness
    def _get_patches(self, img_shape: Tuple[int, int]) -> np.ndarray:
        """Get patches for random erasing."""
        patches = []
        n_patches = np.random.randint(self.n_patches[0],
                                      self.n_patches[1] + 1)
        for _ in range(n_patches):
            if self.squared:
                ratio = np.random.random() * (
                    self.ratio[1] - self.ratio[0]) + self.ratio[0]
                ratio = (ratio, ratio)
            else:
                ratio = (np.random.random() * (self.ratio[1] - self.ratio[0])
                         + self.ratio[0], np.random.random() *
                         (self.ratio[1] - self.ratio[0]) + self.ratio[0])
            ph, pw = int(img_shape[0] * ratio[0]), int(img_shape[1] * ratio[1])
            px1, py1 = np.random.randint(0, img_shape[1] - pw), \
                np.random.randint(0, img_shape[0] - ph)
            px2, py2 = px1 + pw, py1 + ph
            patches.append([px1, py1, px2, py2])
        return np.array(patches)

    def _transform_img(self, results: dict, patches: np.ndarray) -> None:
        """Random erasing the image."""
        for patch in patches:
            px1, py1, px2, py2 = patch
            results['img'][py1:py2, px1:px2, :] = self.img_border_value

    def _transform_bboxes(self, results: dict, patches: np.ndarray) -> None:
        """Random erasing the bboxes."""
        bboxes = results['gt_bboxes']
        # TODO: unify the logic by using operators in BaseBoxes.
        assert isinstance(bboxes, HorizontalBoxes)
        bboxes = bboxes.numpy()
        left_top = np.maximum(bboxes[:, None, :2], patches[:, :2])
        right_bottom = np.minimum(bboxes[:, None, 2:], patches[:, 2:])
        wh = np.maximum(right_bottom - left_top, 0)
        inter_areas = wh[:, :, 0] * wh[:, :, 1]
        bbox_areas = (bboxes[:, 2] - bboxes[:, 0]) * (
            bboxes[:, 3] - bboxes[:, 1])
        bboxes_erased_ratio = inter_areas.sum(-1) / (bbox_areas + 1e-7)
        valid_inds = bboxes_erased_ratio < self.bbox_erased_thr
        results['gt_bboxes'] = HorizontalBoxes(bboxes[valid_inds])
        results['gt_bboxes_labels'] = results['gt_bboxes_labels'][valid_inds]
        results['gt_ignore_flags'] = results['gt_ignore_flags'][valid_inds]
        if results.get('gt_masks', None) is not None:
            results['gt_masks'] = results['gt_masks'][valid_inds]

    def _transform_masks(self, results: dict, patches: np.ndarray) -> None:
        """Random erasing the masks."""
        for patch in patches:
            px1, py1, px2, py2 = patch
            results['gt_masks'].masks[:, py1:py2,
                                      px1:px2] = self.mask_border_value

    def _transform_seg(self, results: dict, patches: np.ndarray) -> None:
        """Random erasing the segmentation map."""
        for patch in patches:
            px1, py1, px2, py2 = patch
            results['gt_seg_map'][py1:py2, px1:px2] = self.seg_ignore_label

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Transform function to erase some regions of image."""
        patches = self._get_patches(results['img_shape'])
        self._transform_img(results, patches)
        if results.get('gt_bboxes', None) is not None:
            self._transform_bboxes(results, patches)
        if results.get('gt_masks', None) is not None:
            self._transform_masks(results, patches)
        if results.get('gt_seg_map', None) is not None:
            self._transform_seg(results, patches)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(n_patches={self.n_patches}, '
        repr_str += f'ratio={self.ratio}, '
        repr_str += f'squared={self.squared}, '
        repr_str += f'bbox_erased_thr={self.bbox_erased_thr}, '
        repr_str += f'img_border_value={self.img_border_value}, '
        repr_str += f'mask_border_value={self.mask_border_value}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label})'
        return repr_str
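

# A minimal usage sketch (illustrative only; `_demo_random_erasing` is not
# part of the upstream API): erase 1 to 3 square patches spanning 10% to 30%
# of each side, dropping boxes whose erased proportion exceeds
# bbox_erased_thr.
def _demo_random_erasing() -> None:
    import numpy as np
    from mmdet.structures.bbox import HorizontalBoxes
    results = dict(
        img=np.full((64, 64, 3), 255, dtype=np.uint8),
        img_shape=(64, 64),
        gt_bboxes=HorizontalBoxes(
            np.array([[4., 4., 20., 20.]], dtype=np.float32)),
        gt_bboxes_labels=np.array([0], dtype=np.int64),
        gt_ignore_flags=np.array([False]))
    out = RandomErasing(n_patches=(1, 3), ratio=(0.1, 0.3)).transform(results)
    # some pixels are now set to the default img_border_value of 128
    assert (out['img'] == 128).any()
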
@TRANSFORMS.register_module()
class CachedMosaic(Mosaic):
    """Cached mosaic augmentation.

    The cached mosaic transform randomly selects images from the cache
    and combines them into one output image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The cached mosaic transform steps are as follows:

    1. Append the results from the last transform into the cache.
    2. Choose the mosaic center as the intersection of the 4 images.
    3. Get the top-left image according to the index, and randomly
       sample another 3 images from the result cache.
    4. Sub images will be cropped if they are larger than the mosaic patch.

    Required Keys:

    - img
    - gt_bboxes (np.float32) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        img_scale (Sequence[int]): Image size after mosaic pipeline of single
            image. The shape order should be (width, height).
            Defaults to (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Defaults to (0.5, 1.5).
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some datasets like MOT17, the gt
            bboxes are allowed to cross the border of images. Therefore, we
            don't need to clip the gt bboxes in these cases.
            Defaults to True.
        pad_val (int): Pad value. Defaults to 114.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 40.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
    """

    def __init__(self,
                 *args,
                 max_cached_images: int = 40,
                 random_pop: bool = True,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.results_cache = []
        self.random_pop = random_pop
        assert max_cached_images >= 4, 'The length of cache must >= 4, ' \
                                       f'but got {max_cached_images}.'
        self.max_cached_images = max_cached_images

    @cache_randomness
    def get_indexes(self, cache: list) -> list:
        """Call function to collect indexes.

        Args:
            cache (list): The results cache.

        Returns:
            list: indexes.
        """
        indexes = [random.randint(0, len(cache) - 1) for _ in range(3)]
        return indexes

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Mosaic transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        # cache and pop images
        self.results_cache.append(copy.deepcopy(results))
        if len(self.results_cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(self.results_cache) - 1)
            else:
                index = 0
            self.results_cache.pop(index)

        if len(self.results_cache) <= 4:
            return results

        if random.uniform(0, 1) > self.prob:
            return results

        indices = self.get_indexes(self.results_cache)
        mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]

        # TODO: refactor mosaic to reuse these code.
        mosaic_bboxes = []
        mosaic_bboxes_labels = []
        mosaic_ignore_flags = []
        mosaic_masks = []
        with_mask = 'gt_masks' in results

        if len(results['img'].shape) == 3:
            mosaic_img = np.full(
                (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
                self.pad_val,
                dtype=results['img'].dtype)
        else:
            mosaic_img = np.full(
                (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
                self.pad_val,
                dtype=results['img'].dtype)

        # mosaic center x, y
        center_x = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[0])
        center_y = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[1])
        center_position = (center_x, center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        for i, loc in enumerate(loc_strs):
            if loc == 'top_left':
                results_patch = copy.deepcopy(results)
            else:
                results_patch = copy.deepcopy(mix_results[i - 1])

            img_i = results_patch['img']
            h_i, w_i = img_i.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(self.img_scale[1] / h_i,
                                self.img_scale[0] / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

            # adjust coordinate
            gt_bboxes_i = results_patch['gt_bboxes']
            gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
            gt_ignore_flags_i = results_patch['gt_ignore_flags']

            padw = x1_p - x1_c
            padh = y1_p - y1_c
            gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
            gt_bboxes_i.translate_([padw, padh])
            mosaic_bboxes.append(gt_bboxes_i)
            mosaic_bboxes_labels.append(gt_bboxes_labels_i)
            mosaic_ignore_flags.append(gt_ignore_flags_i)
            if with_mask and results_patch.get('gt_masks', None) is not None:
                gt_masks_i = results_patch['gt_masks']
                gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
                gt_masks_i = gt_masks_i.translate(
                    out_shape=(int(self.img_scale[0] * 2),
                               int(self.img_scale[1] * 2)),
                    offset=padw,
                    direction='horizontal')
                gt_masks_i = gt_masks_i.translate(
                    out_shape=(int(self.img_scale[0] * 2),
                               int(self.img_scale[1] * 2)),
                    offset=padh,
                    direction='vertical')
                mosaic_masks.append(gt_masks_i)

        mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
        mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
        mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)

        if self.bbox_clip_border:
            mosaic_bboxes.clip_([2 * self.img_scale[1],
                                 2 * self.img_scale[0]])

        # remove outside bboxes
        inside_inds = mosaic_bboxes.is_inside(
            [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
        mosaic_bboxes = mosaic_bboxes[inside_inds]
        mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
        mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape[:2]
        results['gt_bboxes'] = mosaic_bboxes
        results['gt_bboxes_labels'] = mosaic_bboxes_labels
        results['gt_ignore_flags'] = mosaic_ignore_flags

        if with_mask:
            mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
            results['gt_masks'] = mosaic_masks[inside_inds]
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'max_cached_images={self.max_cached_images}, '
        repr_str += f'random_pop={self.random_pop})'
        return repr_str
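

# A minimal usage sketch (illustrative only; `_demo_cached_mosaic` is not
# part of the upstream API): the transform is a no-op until its internal
# cache holds more than 4 results; afterwards it emits a 2x-scale mosaic.
def _demo_cached_mosaic() -> None:
    import numpy as np
    from mmdet.structures.bbox import HorizontalBoxes
    mosaic = CachedMosaic(img_scale=(64, 64), max_cached_images=8)
    out = None
    for _ in range(6):
        results = dict(
            img=np.random.randint(0, 256, (48, 48, 3), dtype=np.uint8),
            gt_bboxes=HorizontalBoxes(
                np.array([[4., 4., 20., 20.]], dtype=np.float32)),
            gt_bboxes_labels=np.array([0], dtype=np.int64),
            gt_ignore_flags=np.array([False]))
        out = mosaic.transform(results)
    # output canvas is (2 * height, 2 * width) of img_scale
    assert out['img'].shape[:2] == (128, 128)
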
@TRANSFORMS.register_module()
class CachedMixUp(BaseTransform):
    """Cached mixup data augmentation.

    .. code:: text

                        mixup transform
               +------------------------------+
               | mixup image   |              |
               |      +--------|--------+     |
               |      |        |        |     |
               |---------------+        |     |
               |      |                 |     |
               |      |      image      |     |
               |      |                 |     |
               |      |                 |     |
               |      |-----------------+     |
               |             pad              |
               +------------------------------+

    The cached mixup transform steps are as follows:

    1. Append the results from the last transform into the cache.
    2. Another random image is picked from the cache and embedded in
       the top left patch (after padding and resizing).
    3. The target of mixup transform is the weighted average of the mixup
       image and the original image.

    Required Keys:

    - img
    - gt_bboxes (np.float32) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        img_scale (Sequence[int]): Image output size after mixup pipeline.
            The shape order should be (width, height). Defaults to (640, 640).
        ratio_range (Sequence[float]): Scale ratio of mixup image.
            Defaults to (0.5, 1.5).
        flip_ratio (float): Horizontal flip ratio of mixup image.
            Defaults to 0.5.
        pad_val (int): Pad value. Defaults to 114.
        max_iters (int): The maximum number of iterations. If the number of
            iterations is greater than ``max_iters``, but gt_bbox is still
            empty, then the iteration is terminated. Defaults to 15.
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some datasets like MOT17, the gt
            bboxes are allowed to cross the border of images. Therefore, we
            don't need to clip the gt bboxes in these cases.
            Defaults to True.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 20.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
    """

    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),
                 ratio_range: Tuple[float, float] = (0.5, 1.5),
                 flip_ratio: float = 0.5,
                 pad_val: float = 114.0,
                 max_iters: int = 15,
                 bbox_clip_border: bool = True,
                 max_cached_images: int = 20,
                 random_pop: bool = True,
                 prob: float = 1.0) -> None:
        assert isinstance(img_scale, tuple)
        assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
                                       f'but got {max_cached_images}.'
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
                                 f'got {prob}.'
        self.dynamic_scale = img_scale
        self.ratio_range = ratio_range
        self.flip_ratio = flip_ratio
        self.pad_val = pad_val
        self.max_iters = max_iters
        self.bbox_clip_border = bbox_clip_border
        self.results_cache = []

        self.max_cached_images = max_cached_images
        self.random_pop = random_pop
        self.prob = prob

    @cache_randomness
    def get_indexes(self, cache: list) -> int:
        """Call function to collect indexes.

        Args:
            cache (list): The result cache.

        Returns:
            int: index.
        """
        for i in range(self.max_iters):
            index = random.randint(0, len(cache) - 1)
            gt_bboxes_i = cache[index]['gt_bboxes']
            if len(gt_bboxes_i) != 0:
                break
        return index

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """MixUp transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        # cache and pop images
        self.results_cache.append(copy.deepcopy(results))
        if len(self.results_cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(self.results_cache) - 1)
            else:
                index = 0
            self.results_cache.pop(index)

        if len(self.results_cache) <= 1:
            return results

        if random.uniform(0, 1) > self.prob:
            return results

        index = self.get_indexes(self.results_cache)
        retrieve_results = copy.deepcopy(self.results_cache[index])

        # TODO: refactor mixup to reuse these code.
        if retrieve_results['gt_bboxes'].shape[0] == 0:
            # empty bbox
            return results

        retrieve_img = retrieve_results['img']
        with_mask = 'gt_masks' in results

        jit_factor = random.uniform(*self.ratio_range)
        is_flip = random.uniform(0, 1) > self.flip_ratio

        if len(retrieve_img.shape) == 3:
            out_img = np.ones(
                (self.dynamic_scale[1], self.dynamic_scale[0], 3),
                dtype=retrieve_img.dtype) * self.pad_val
        else:
            out_img = np.ones(
                self.dynamic_scale[::-1],
                dtype=retrieve_img.dtype) * self.pad_val

        # 1. keep_ratio resize
        scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
                          self.dynamic_scale[0] / retrieve_img.shape[1])
        retrieve_img = mmcv.imresize(
            retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
                           int(retrieve_img.shape[0] * scale_ratio)))

        # 2. paste
        out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img

        # 3. scale jit
        scale_ratio *= jit_factor
        out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
                                          int(out_img.shape[0] * jit_factor)))

        # 4. flip
        if is_flip:
            out_img = out_img[:, ::-1, :]

        # 5. random crop
        ori_img = results['img']
        origin_h, origin_w = out_img.shape[:2]
        target_h, target_w = ori_img.shape[:2]
        padded_img = np.ones((max(origin_h, target_h), max(
            origin_w, target_w), 3)) * self.pad_val
        padded_img = padded_img.astype(np.uint8)
        padded_img[:origin_h, :origin_w] = out_img

        x_offset, y_offset = 0, 0
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w)
        padded_cropped_img = padded_img[y_offset:y_offset + target_h,
                                        x_offset:x_offset + target_w]

        # 6. adjust bbox
        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
        if with_mask:
            retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
                scale_ratio)

        if self.bbox_clip_border:
            retrieve_gt_bboxes.clip_([origin_h, origin_w])

        if is_flip:
            retrieve_gt_bboxes.flip_([origin_h, origin_w],
                                     direction='horizontal')
            if with_mask:
                retrieve_gt_masks = retrieve_gt_masks.flip()

        # 7. filter
        cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
        cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
        if with_mask:
            retrieve_gt_masks = retrieve_gt_masks.translate(
                out_shape=(target_h, target_w),
                offset=-x_offset,
                direction='horizontal')
            retrieve_gt_masks = retrieve_gt_masks.translate(
                out_shape=(target_h, target_w),
                offset=-y_offset,
                direction='vertical')

        if self.bbox_clip_border:
            cp_retrieve_gt_bboxes.clip_([target_h, target_w])

        # 8. mix up
        ori_img = ori_img.astype(np.float32)
        mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)

        retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
        retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']

        mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
            (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
        mixup_gt_bboxes_labels = np.concatenate(
            (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
        mixup_gt_ignore_flags = np.concatenate(
            (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
        if with_mask:
            mixup_gt_masks = retrieve_gt_masks.cat(
                [results['gt_masks'], retrieve_gt_masks])

        # remove outside bbox
        inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
        mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
        mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
        mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
        if with_mask:
            mixup_gt_masks = mixup_gt_masks[inside_inds]

        results['img'] = mixup_img.astype(np.uint8)
        results['img_shape'] = mixup_img.shape[:2]
        results['gt_bboxes'] = mixup_gt_bboxes
        results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
        results['gt_ignore_flags'] = mixup_gt_ignore_flags
        if with_mask:
            results['gt_masks'] = mixup_gt_masks
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(dynamic_scale={self.dynamic_scale}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'flip_ratio={self.flip_ratio}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'max_iters={self.max_iters}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border}, '
        repr_str += f'max_cached_images={self.max_cached_images}, '
        repr_str += f'random_pop={self.random_pop}, '
        repr_str += f'prob={self.prob})'
        return repr_str
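

# A minimal usage sketch (illustrative only; `_demo_cached_mixup` is not part
# of the upstream API): the transform needs at least one earlier result with
# non-empty gt_bboxes in its cache before it starts blending images 50/50,
# and the output keeps the shape of the current image.
def _demo_cached_mixup() -> None:
    import numpy as np
    from mmdet.structures.bbox import HorizontalBoxes
    mixup = CachedMixUp(img_scale=(64, 64), max_cached_images=4)
    out = None
    for _ in range(3):
        results = dict(
            img=np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
            gt_bboxes=HorizontalBoxes(
                np.array([[8., 8., 24., 24.]], dtype=np.float32)),
            gt_bboxes_labels=np.array([0], dtype=np.int64),
            gt_ignore_flags=np.array([False]))
        out = mixup.transform(results)
    assert out['img'].shape[:2] == (64, 64)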