ae_loss.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmdet.registry import MODELS
  6. def ae_loss_per_image(tl_preds, br_preds, match):
  7. """Associative Embedding Loss in one image.
  8. Associative Embedding Loss including two parts: pull loss and push loss.
  9. Pull loss makes embedding vectors from same object closer to each other.
  10. Push loss distinguish embedding vector from different objects, and makes
  11. the gap between them is large enough.
  12. During computing, usually there are 3 cases:
  13. - no object in image: both pull loss and push loss will be 0.
  14. - one object in image: push loss will be 0 and pull loss is computed
  15. by the two corner of the only object.
  16. - more than one objects in image: pull loss is computed by corner pairs
  17. from each object, push loss is computed by each object with all
  18. other objects. We use confusion matrix with 0 in diagonal to
  19. compute the push loss.
  20. Args:
  21. tl_preds (tensor): Embedding feature map of left-top corner.
  22. br_preds (tensor): Embedding feature map of bottim-right corner.
  23. match (list): Downsampled coordinates pair of each ground truth box.
  24. """
  25. tl_list, br_list, me_list = [], [], []
  26. if len(match) == 0: # no object in image
  27. pull_loss = tl_preds.sum() * 0.
  28. push_loss = tl_preds.sum() * 0.
  29. else:
  30. for m in match:
  31. [tl_y, tl_x], [br_y, br_x] = m
  32. tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
  33. br_e = br_preds[:, br_y, br_x].view(-1, 1)
  34. tl_list.append(tl_e)
  35. br_list.append(br_e)
  36. me_list.append((tl_e + br_e) / 2.0)
  37. tl_list = torch.cat(tl_list)
  38. br_list = torch.cat(br_list)
  39. me_list = torch.cat(me_list)
  40. assert tl_list.size() == br_list.size()
  41. # N is object number in image, M is dimension of embedding vector
  42. N, M = tl_list.size()
  43. pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
  44. pull_loss = pull_loss.sum() / N
  45. margin = 1 # exp setting of CornerNet, details in section 3.3 of paper
  46. # confusion matrix of push loss
  47. conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
  48. conf_weight = 1 - torch.eye(N).type_as(me_list)
  49. conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
  50. if N > 1: # more than one object in current image
  51. push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
  52. else:
  53. push_loss = tl_preds.sum() * 0.
  54. return pull_loss, push_loss
  55. @MODELS.register_module()
  56. class AssociativeEmbeddingLoss(nn.Module):
  57. """Associative Embedding Loss.
  58. More details can be found in
  59. `Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
  60. `CornerNet <https://arxiv.org/abs/1808.01244>`_ .
  61. Code is modified from `kp_utils.py <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L180>`_ # noqa: E501
  62. Args:
  63. pull_weight (float): Loss weight for corners from same object.
  64. push_weight (float): Loss weight for corners from different object.
  65. """
  66. def __init__(self, pull_weight=0.25, push_weight=0.25):
  67. super(AssociativeEmbeddingLoss, self).__init__()
  68. self.pull_weight = pull_weight
  69. self.push_weight = push_weight
  70. def forward(self, pred, target, match):
  71. """Forward function."""
  72. batch = pred.size(0)
  73. pull_all, push_all = 0.0, 0.0
  74. for i in range(batch):
  75. pull, push = ae_loss_per_image(pred[i], target[i], match[i])
  76. pull_all += self.pull_weight * pull
  77. push_all += self.push_weight * push
  78. return pull_all, push_all