decoder.py
# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init
from torch.nn import init

from mmdet.registry import MODELS


def _make_stack_3x3_convs(num_convs,
                          in_channels,
                          out_channels,
                          act_cfg=dict(type='ReLU', inplace=True)):
    convs = []
    for _ in range(num_convs):
        convs.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
        convs.append(MODELS.build(act_cfg))
        in_channels = out_channels
    return nn.Sequential(*convs)


class InstanceBranch(nn.Module):

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_masks = num_masks
        self.num_classes = num_classes
        self.inst_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        # iam prediction, a simple conv
        self.iam_conv = nn.Conv2d(dim, num_masks, 3, padding=1)
        # outputs
        self.cls_score = nn.Linear(dim, self.num_classes)
        self.mask_kernel = nn.Linear(dim, kernel_dim)
        self.objectness = nn.Linear(dim, 1)
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        # focal-loss style bias init so the initial foreground probability
        # equals prior_prob
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)
        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()
        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
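        # Shape walk-through for the aggregation above (illustrative note):
        #   iam_prob: (B, N, HW), each row normalised to sum to 1
        #   features: (B, C, HW) permuted to (B, HW, C)
        #   bmm:      (B, N, HW) @ (B, HW, C) -> (B, N, C)
        # so each of the N instance vectors is an activation-weighted average
        # of the C-dimensional pixel features.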
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam


class MaskBranch(nn.Module):

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.mask_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        self.projection = nn.Conv2d(dim, kernel_dim, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        for m in self.mask_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        kaiming_init(self.projection)

    def forward(self, features):
        # mask features (x4 convs)
        features = self.mask_convs(features)
        return self.projection(features)
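
# Note: the projection maps mask features to kernel_dim channels so that the
# per-instance kernels predicted by InstanceBranch.mask_kernel can be applied
# to them with a single batched matrix multiply in BaseIAMDecoder.forward.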


@MODELS.register_module()
class BaseIAMDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 num_classes,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        # in_channels must already include the 2 coordinate channels that
        # forward() concatenates, i.e. encoder channels + 2
        self.scale_factor = scale_factor
        self.output_iam = output_iam
        self.inst_branch = InstanceBranch(
            in_channels,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.mask_branch = MaskBranch(
            in_channels,
            dim=mask_dim,
            num_convs=mask_conv,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)

    @torch.no_grad()
    def compute_coordinates_linspace(self, x):
        # linspace is not supported in ONNX
        h, w = x.size(2), x.size(3)
        y_loc = torch.linspace(-1, 1, h, device=x.device)
        x_loc = torch.linspace(-1, 1, w, device=x.device)
        # 'ij' indexing matches the old default and silences the meshgrid
        # warning on newer PyTorch (requires torch >= 1.10)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc, indexing='ij')
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    @torch.no_grad()
    def compute_coordinates(self, x):
        # ONNX-friendly variant of the above, built from arange
        h, w = x.size(2), x.size(3)
        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / (h - 1)
        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / (w - 1)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc, indexing='ij')
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    def forward(self, features):
        coord_features = self.compute_coordinates(features)
        features = torch.cat([coord_features, features], dim=1)
        pred_logits, pred_kernel, pred_scores, iam = self.inst_branch(features)
        mask_features = self.mask_branch(features)
        N = pred_kernel.shape[1]
        # mask_features: BxCxHxW
        B, C, H, W = mask_features.shape
        pred_masks = torch.bmm(pred_kernel,
                               mask_features.view(B, C,
                                                  H * W)).view(B, N, H, W)
        pred_masks = F.interpolate(
            pred_masks,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)
        output = {
            'pred_logits': pred_logits,
            'pred_masks': pred_masks,
            'pred_scores': pred_scores,
        }
        if self.output_iam:
            iam = F.interpolate(
                iam,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=False)
            output['pred_iam'] = iam
        return output
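
# Usage sketch (illustrative, not part of the module). Assuming a 256-channel
# encoder, in_channels must include the 2 coordinate channels, so 256 + 2:
#
#   decoder = BaseIAMDecoder(in_channels=258, num_classes=80)
#   feats = torch.rand(2, 256, 64, 64)   # BxCxHxW encoder features
#   out = decoder(feats)
#   # out['pred_logits']: (2, 100, 80)
#   # out['pred_masks']:  (2, 100, 128, 128) after 2x bilinear upsampling
#   # out['pred_scores']: (2, 100, 1)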


class GroupInstanceBranch(nn.Module):

    def __init__(self,
                 in_channels,
                 num_groups=4,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_groups = num_groups
        self.num_classes = num_classes
        self.inst_convs = _make_stack_3x3_convs(
            num_convs, in_channels, dim, act_cfg=act_cfg)
        # iam prediction, a group conv
        expand_dim = dim * self.num_groups
        self.iam_conv = nn.Conv2d(
            dim,
            num_masks * self.num_groups,
            3,
            padding=1,
            groups=self.num_groups)
        # outputs
        self.fc = nn.Linear(expand_dim, expand_dim)
        self.cls_score = nn.Linear(expand_dim, self.num_classes)
        self.mask_kernel = nn.Linear(expand_dim, kernel_dim)
        self.objectness = nn.Linear(expand_dim, 1)
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)
        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)
        caffe2_xavier_init(self.fc)

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()
        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        # regroup: use self.num_groups rather than a hardcoded 4 so that
        # non-default num_groups values work
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)
        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
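
# Regrouping note (illustrative): with num_groups=4, num_masks=100 and
# dim=256, the group conv predicts N = 400 maps, and the reshape/transpose
# above concatenates the 4 group-wise 256-d vectors belonging to the same
# instance into one 1024-d vector, i.e. (B, 400, 256) -> (B, 100, 1024),
# which matches the expand_dim = dim * num_groups inputs of fc/cls_score.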


@MODELS.register_module()
class GroupIAMDecoder(BaseIAMDecoder):

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        self.inst_branch = GroupInstanceBranch(
            in_channels,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)


class GroupInstanceSoftBranch(GroupInstanceBranch):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.softmax_bias = nn.Parameter(torch.ones(1))

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        B, N = iam.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW); softmax over spatial locations replaces the
        # sigmoid + sum-normalisation of the base branch
        iam_prob = F.softmax(iam.view(B, N, -1) + self.softmax_bias, dim=-1)
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)
        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam


@MODELS.register_module()
class GroupIAMSoftDecoder(BaseIAMDecoder):

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        self.inst_branch = GroupInstanceSoftBranch(
            in_channels,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
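

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch; assumes an mmdet/mmcv install
    # so the 'ReLU' act_cfg can be built from the MODELS registry).
    decoder = GroupIAMSoftDecoder(in_channels=258, num_classes=80)
    feats = torch.rand(2, 256, 32, 32)  # encoder output without coordinates
    out = decoder(feats)
    for k, v in out.items():
        print(k, tuple(v.shape))
    # expected: pred_logits (2, 100, 80), pred_masks (2, 100, 64, 64),
    # pred_scores (2, 100, 1)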