# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps
from .utils import weighted_loss


@weighted_loss
def iou_loss(pred: Tensor,
             target: Tensor,
             linear: bool = False,
             mode: str = 'log',
             eps: float = 1e-6) -> Tensor:
    """IoU loss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes. With the default ``mode='log'`` the loss is the negative log of
    IoU; "linear" and "square" scalings are also supported.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        linear (bool, optional): If True, use linear scale of loss instead of
            log scale. Default: False.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'.
        eps (float): Epsilon to avoid log(0).

    Returns:
        Tensor: Loss tensor.
    """
    assert mode in ['linear', 'square', 'log']
    if linear:
        mode = 'linear'
        warnings.warn('DeprecationWarning: Setting "linear=True" in '
                      'iou_loss is deprecated, please use "mode=`linear`" '
                      'instead.')
    ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
    if mode == 'linear':
        loss = 1 - ious
    elif mode == 'square':
        loss = 1 - ious**2
    elif mode == 'log':
        loss = -ious.log()
    else:
        raise NotImplementedError
    return loss
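
# Usage sketch (editorial addition, not part of the upstream module):
# `iou_loss` is aligned, i.e. row i of `pred` is compared with row i of
# `target`; the `reduction` keyword comes from the `weighted_loss`
# decorator. For an IoU of 0.5 the three modes give 0.5 ("linear"),
# 0.75 ("square") and -log(0.5) ~= 0.693 ("log").
#
#   >>> pred = torch.tensor([[0., 0., 10., 10.]])
#   >>> target = torch.tensor([[0., 0., 10., 5.]])  # IoU = 0.5
#   >>> iou_loss(pred, target, mode='linear', reduction='none')
#   tensor([0.5000])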


@weighted_loss
def bounded_iou_loss(pred: Tensor,
                     target: Tensor,
                     beta: float = 0.2,
                     eps: float = 1e-3) -> Tensor:
    """BIoULoss.

    This is an implementation of paper
    `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
    <https://arxiv.org/abs/1711.00164>`_.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        beta (float, optional): Beta parameter in smoothl1.
        eps (float, optional): Epsilon to avoid NaN values.

    Returns:
        Tensor: Loss tensor.
    """
    pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
    pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
    pred_w = pred[:, 2] - pred[:, 0]
    pred_h = pred[:, 3] - pred[:, 1]
    with torch.no_grad():
        target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
        target_ctry = (target[:, 1] + target[:, 3]) * 0.5
        target_w = target[:, 2] - target[:, 0]
        target_h = target[:, 3] - target[:, 1]

    dx = target_ctrx - pred_ctrx
    dy = target_ctry - pred_ctry

    loss_dx = 1 - torch.max(
        (target_w - 2 * dx.abs()) /
        (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
    loss_dy = 1 - torch.max(
        (target_h - 2 * dy.abs()) /
        (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
    loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
                            (target_w + eps))
    loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
                            (target_h + eps))
    # view(..., -1) does not work for empty tensor
    loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
                            dim=-1).flatten(1)

    loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
                       loss_comb - 0.5 * beta)
    return loss
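
# NOTE (editorial): each of the four per-coordinate bounded-IoU terms above
# is passed through a smooth-L1 envelope,
#     f(x) = 0.5 * x**2 / beta   if x < beta
#            x - 0.5 * beta      otherwise,
# and the result keeps shape (n, 4), unlike the other losses in this file,
# which return one value per box pair.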


@weighted_loss
def giou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression <https://arxiv.org/abs/1902.09630>`_.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
    loss = 1 - gious
    return loss
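
# NOTE (editorial): GIoU extends IoU with a penalty for the empty area of
# the smallest enclosing box C:
#     GIoU = IoU - |C \ (A U B)| / |C|,
# so the loss 1 - GIoU lies in [0, 2] and still provides a gradient for
# non-overlapping boxes, where plain IoU is constant at zero.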


@weighted_loss
def diou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Implementation of `Distance-IoU Loss: Faster and Better Learning for
    Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.

    Code is modified from https://github.com/Zzh-tju/DIoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    # overlap
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]

    # union
    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    cw = enclose_wh[:, 0]
    ch = enclose_wh[:, 1]

    # squared diagonal length of the enclosing box
    c2 = cw**2 + ch**2 + eps

    b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
    b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
    b2_x1, b2_y1 = target[:, 0], target[:, 1]
    b2_x2, b2_y2 = target[:, 2], target[:, 3]

    # squared distance between the box centers
    left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
    right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
    rho2 = left + right

    # DIoU
    dious = ious - rho2 / c2
    loss = 1 - dious
    return loss
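
# NOTE (editorial): DIoU = IoU - rho^2(b, b_gt) / c^2, where rho is the
# distance between the two box centers and c is the diagonal of the smallest
# enclosing box; the extra term directly pulls the centers together even
# when the boxes do not overlap.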


@weighted_loss
def ciou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Implementation of paper `Enhancing Geometric Factors in Model
    Learning and Inference for Object Detection and Instance Segmentation
    <https://arxiv.org/abs/2005.03572>`_.

    Code is modified from https://github.com/Zzh-tju/CIoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    # overlap
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]

    # union
    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    cw = enclose_wh[:, 0]
    ch = enclose_wh[:, 1]

    # squared diagonal length of the enclosing box
    c2 = cw**2 + ch**2 + eps

    b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
    b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
    b2_x1, b2_y1 = target[:, 0], target[:, 1]
    b2_x2, b2_y2 = target[:, 2], target[:, 3]

    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # squared distance between the box centers, as in DIoU
    left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
    right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
    rho2 = left + right

    # v measures the consistency of the two aspect ratios
    factor = 4 / math.pi**2
    v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

    # alpha is a trade-off weight, treated as a constant for the gradient
    with torch.no_grad():
        alpha = (ious > 0.5).float() * v / (1 - ious + v)

    # CIoU
    cious = ious - (rho2 / c2 + alpha * v)
    loss = 1 - cious.clamp(min=-1.0, max=1.0)
    return loss
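
# NOTE (editorial): CIoU = IoU - rho^2 / c^2 - alpha * v, i.e. DIoU plus an
# aspect-ratio term v = (4 / pi^2) * (arctan(w_gt / h_gt) - arctan(w / h))^2.
# In this implementation alpha gates the penalty to pairs with IoU > 0.5, so
# the aspect-ratio term only acts once the boxes already overlap reasonably.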


@weighted_loss
def eiou_loss(pred: Tensor,
              target: Tensor,
              smooth_point: float = 0.1,
              eps: float = 1e-7) -> Tensor:
    r"""Implementation of paper `Extended-IoU Loss: A Systematic
    IoU-Related Method: Beyond Simplified Regression for Better
    Localization <https://ieeexplore.ieee.org/abstract/document/9429909>`_.

    Code is modified from https://github.com/ShiqiYu/libfacedetection.train.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        smooth_point (float): Hyperparameter, default is 0.1.
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    px1, py1, px2, py2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
    tx1, ty1, tx2, ty2 = target[:, 0], target[:, 1], target[:, 2], \
        target[:, 3]

    # extent top left
    ex1 = torch.min(px1, tx1)
    ey1 = torch.min(py1, ty1)

    # intersection coordinates
    ix1 = torch.max(px1, tx1)
    iy1 = torch.max(py1, ty1)
    ix2 = torch.min(px2, tx2)
    iy2 = torch.min(py2, ty2)

    # extra
    xmin = torch.min(ix1, ix2)
    ymin = torch.min(iy1, iy2)
    xmax = torch.max(ix1, ix2)
    ymax = torch.max(iy1, iy2)

    # Intersection
    intersection = (ix2 - ex1) * (iy2 - ey1) + (xmin - ex1) * (ymin - ey1) - (
        ix1 - ex1) * (ymax - ey1) - (xmax - ex1) * (
            iy1 - ey1)
    # Union
    union = (px2 - px1) * (py2 - py1) + (tx2 - tx1) * (
        ty2 - ty1) - intersection + eps
    # 1 - IoU (despite the name, `ious` holds the linear IoU loss)
    ious = 1 - (intersection / union)

    # Smooth-EIoU
    smooth_sign = (ious < smooth_point).detach().float()
    loss = 0.5 * smooth_sign * (ious**2) / smooth_point + (1 - smooth_sign) * (
        ious - 0.5 * smooth_point)
    return loss
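
# NOTE (editorial): the final step applies a smooth-L1-style switch at
# `smooth_point` to x = 1 - IoU: a quadratic branch 0.5 * x**2 / smooth_point
# for x < smooth_point, and a linear branch x - 0.5 * smooth_point otherwise.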


@MODELS.register_module()
class IoULoss(nn.Module):
    """IoULoss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes.

    Args:
        linear (bool): If True, use linear scale of loss else determined
            by mode. Default: False.
        eps (float): Epsilon to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'.
    """

    def __init__(self,
                 linear: bool = False,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 mode: str = 'log') -> None:
        super().__init__()
        assert mode in ['linear', 'square', 'log']
        if linear:
            mode = 'linear'
            warnings.warn('DeprecationWarning: Setting "linear=True" in '
                          'IoULoss is deprecated, please use "mode=`linear`" '
                          'instead.')
        self.mode = mode
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * iou_loss(
            pred,
            target,
            weight,
            mode=self.mode,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss
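
# Usage sketch (editorial addition): the classes below are registered in the
# MODELS registry, so in MMDetection configs they are usually selected via a
# config dict in a detector head, e.g.
#
#   loss_bbox = dict(type='IoULoss', mode='linear', loss_weight=10.0)
#
# They can also be instantiated directly:
#
#   criterion = IoULoss(mode='linear')
#   loss = criterion(pred_bboxes, gt_bboxes)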


@MODELS.register_module()
class BoundedIoULoss(nn.Module):
    """BIoULoss.

    This is an implementation of paper
    `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
    <https://arxiv.org/abs/1711.00164>`_.

    Args:
        beta (float, optional): Beta parameter in smoothl1.
        eps (float, optional): Epsilon to avoid NaN values.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 beta: float = 0.2,
                 eps: float = 1e-3,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * bounded_iou_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class GIoULoss(nn.Module):
    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression <https://arxiv.org/abs/1902.09630>`_.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # giou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * giou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class DIoULoss(nn.Module):
    r"""Implementation of `Distance-IoU Loss: Faster and Better Learning for
    Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.

    Code is modified from https://github.com/Zzh-tju/DIoU.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # diou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * diou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class CIoULoss(nn.Module):
    r"""Implementation of paper `Enhancing Geometric Factors in Model
    Learning and Inference for Object Detection and Instance Segmentation
    <https://arxiv.org/abs/2005.03572>`_.

    Code is modified from https://github.com/Zzh-tju/CIoU.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # ciou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * ciou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class EIoULoss(nn.Module):
    r"""Implementation of paper `Extended-IoU Loss: A Systematic
    IoU-Related Method: Beyond Simplified Regression for Better
    Localization <https://ieeexplore.ieee.org/abstract/document/9429909>`_.

    Code is modified from https://github.com/ShiqiYu/libfacedetection.train.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        smooth_point (float): Hyperparameter, default is 0.1.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 smooth_point: float = 0.1) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.smooth_point = smooth_point

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * eiou_loss(
            pred,
            target,
            weight,
            smooth_point=self.smooth_point,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss
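

# Minimal smoke test (editorial sketch, not part of the upstream module); it
# assumes `mmdet` is installed so the imports at the top of this file
# resolve. Identical box pairs are avoided because `ciou_loss` can produce
# NaN via 0/0 in `alpha` when IoU is exactly 1 and v is 0.
if __name__ == '__main__':
    pred = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
    target = torch.tensor([[0., 0., 10., 5.], [0., 0., 10., 10.]])
    for name, fn in [('iou', iou_loss), ('giou', giou_loss),
                     ('diou', diou_loss), ('ciou', ciou_loss),
                     ('eiou', eiou_loss)]:
        # reduction='none' keeps one loss value per box pair
        print(name, fn(pred, target, reduction='none'))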