# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/detector.py # noqa
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/head.py # noqa
# This work is licensed under the CC-BY-NC 4.0 License.
# Users should be careful about adopting these features in any commercial matters. # noqa
# For more details, please refer to https://github.com/ShoufaChen/DiffusionDet/blob/main/LICENSE # noqa
import copy
import math
import random
import warnings
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.ops import batched_nms
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import (bbox2roi, bbox_cxcywh_to_xyxy,
                                   bbox_xyxy_to_cxcywh, get_box_wh,
                                   scale_boxes)
from mmdet.utils import InstanceList

_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(
        ((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
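

# A quick doctest-style check (illustrative sketch, not part of the original
# file): the schedule returns one beta per timestep, clipped to [0, 0.999].
#
#   >>> betas = cosine_beta_schedule(timesteps=1000)
#   >>> betas.shape
#   torch.Size([1000])
#   >>> bool((betas >= 0).all() and (betas <= 0.999).all())
#   True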


def extract(a, t, x_shape):
    """extract the appropriate t index for a batch of indices."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1, ) * (len(x_shape) - 1)))
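

# Illustrative sketch: `extract` gathers one scalar coefficient per sample and
# reshapes it so it broadcasts against a batch of boxes, e.g. (B, N, 4).
#
#   >>> a = torch.arange(10, dtype=torch.float64)
#   >>> t = torch.tensor([2, 7])
#   >>> extract(a, t, x_shape=(2, 500, 4)).shape
#   torch.Size([2, 1, 1])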


class SinusoidalPositionEmbeddings(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(
            torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
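

# Illustrative sketch: each scalar timestep is mapped to a `dim`-dimensional
# vector of sine/cosine features, as in the Transformer position encoding.
#
#   >>> emb = SinusoidalPositionEmbeddings(dim=256)
#   >>> emb(torch.tensor([0, 500, 999])).shape
#   torch.Size([3, 256])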


@MODELS.register_module()
class DynamicDiffusionDetHead(nn.Module):

    def __init__(self,
                 num_classes=80,
                 feat_channels=256,
                 num_proposals=500,
                 num_heads=6,
                 prior_prob=0.01,
                 snr_scale=2.0,
                 timesteps=1000,
                 sampling_timesteps=1,
                 self_condition=False,
                 box_renewal=True,
                 use_ensemble=True,
                 deep_supervision=True,
                 ddim_sampling_eta=1.0,
                 criterion=dict(
                     type='DiffusionDetCriterion',
                     num_classes=80,
                     assigner=dict(
                         type='DiffusionDetMatcher',
                         match_costs=[
                             dict(
                                 type='FocalLossCost',
                                 alpha=0.25,
                                 gamma=2.0,
                                 weight=2.0),
                             dict(
                                 type='BBoxL1Cost',
                                 weight=5.0,
                                 box_format='xyxy'),
                             dict(type='IoUCost', iou_mode='giou', weight=2.0)
                         ],
                         center_radius=2.5,
                         candidate_topk=5),
                 ),
                 single_head=dict(
                     type='DiffusionDetHead',
                     num_cls_convs=1,
                     num_reg_convs=3,
                     dim_feedforward=2048,
                     num_heads=8,
                     dropout=0.0,
                     act_cfg=dict(type='ReLU'),
                     dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)),
                 roi_extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(
                         type='RoIAlign', output_size=7, sampling_ratio=2),
                     out_channels=256,
                     featmap_strides=[4, 8, 16, 32]),
                 test_cfg=None,
                 **kwargs) -> None:
        super().__init__()
        self.roi_extractor = MODELS.build(roi_extractor)

        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.num_proposals = num_proposals
        self.num_heads = num_heads

        # Build diffusion
        assert isinstance(timesteps, int), 'The type of `timesteps` should ' \
                                           f'be int but got {type(timesteps)}'
        assert sampling_timesteps <= timesteps
        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.snr_scale = snr_scale
        self.ddim_sampling = self.sampling_timesteps < self.timesteps
        self.ddim_sampling_eta = ddim_sampling_eta
        self.self_condition = self_condition
        self.box_renewal = box_renewal
        self.use_ensemble = use_ensemble
        self._build_diffusion()

        # Build assigner
        assert criterion.get('assigner', None) is not None
        assigner = TASK_UTILS.build(criterion.get('assigner'))
        # Init parameters.
        self.use_focal_loss = assigner.use_focal_loss
        self.use_fed_loss = assigner.use_fed_loss

        # Build criterion
        criterion.update(deep_supervision=deep_supervision)
        self.criterion = TASK_UTILS.build(criterion)

        # Build Dynamic Head.
        single_head_ = single_head.copy()
        single_head_num_classes = single_head_.get('num_classes', None)
        if single_head_num_classes is None:
            single_head_.update(num_classes=num_classes)
        else:
            if single_head_num_classes != num_classes:
                warnings.warn(
                    'The `num_classes` of `DynamicDiffusionDetHead` and '
                    '`SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.num_classes` to {num_classes}')
                single_head_.update(num_classes=num_classes)

        single_head_feat_channels = single_head_.get('feat_channels', None)
        if single_head_feat_channels is None:
            single_head_.update(feat_channels=feat_channels)
        else:
            if single_head_feat_channels != feat_channels:
                warnings.warn(
                    'The `feat_channels` of `DynamicDiffusionDetHead` and '
                    '`SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.feat_channels` to {feat_channels}')
                single_head_.update(feat_channels=feat_channels)

        default_pooler_resolution = roi_extractor['roi_layer'].get(
            'output_size')
        assert default_pooler_resolution is not None
        single_head_pooler_resolution = single_head_.get('pooler_resolution')
        if single_head_pooler_resolution is None:
            single_head_.update(pooler_resolution=default_pooler_resolution)
        else:
            if single_head_pooler_resolution != default_pooler_resolution:
                warnings.warn(
                    'The `pooler_resolution` of `DynamicDiffusionDetHead` '
                    'and `SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.pooler_resolution` to '
                    f'{default_pooler_resolution}')
                single_head_.update(
                    pooler_resolution=default_pooler_resolution)

        single_head_.update(
            use_focal_loss=self.use_focal_loss, use_fed_loss=self.use_fed_loss)
        single_head_module = MODELS.build(single_head_)
        self.head_series = nn.ModuleList(
            [copy.deepcopy(single_head_module) for _ in range(num_heads)])

        self.deep_supervision = deep_supervision

        # Gaussian random feature embedding layer for time
        time_dim = feat_channels * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(feat_channels),
            nn.Linear(feat_channels, time_dim), nn.GELU(),
            nn.Linear(time_dim, time_dim))

        self.prior_prob = prior_prob
        self.test_cfg = test_cfg
        self.use_nms = self.test_cfg.get('use_nms', True)
        self._init_weights()

    def _init_weights(self):
        # init all parameters.
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

            # initialize the bias for focal loss and fed loss.
            if self.use_focal_loss or self.use_fed_loss:
                if p.shape[-1] == self.num_classes or \
                        p.shape[-1] == self.num_classes + 1:
                    nn.init.constant_(p, bias_value)

    def _build_diffusion(self):
        betas = cosine_beta_schedule(self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))

    def forward(self, features, init_bboxes, init_t, init_features=None):
        time = self.time_mlp(init_t)

        inter_class_logits = []
        inter_pred_bboxes = []

        bs = len(features[0])
        bboxes = init_bboxes

        if init_features is not None:
            init_features = init_features[None].repeat(1, bs, 1)
            proposal_features = init_features.clone()
        else:
            proposal_features = None

        for head_idx, single_head in enumerate(self.head_series):
            class_logits, pred_bboxes, proposal_features = single_head(
                features, bboxes, proposal_features, self.roi_extractor, time)
            if self.deep_supervision:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            bboxes = pred_bboxes.detach()

        if self.deep_supervision:
            return torch.stack(inter_class_logits), torch.stack(
                inter_pred_bboxes)
        else:
            return class_logits[None, ...], pred_bboxes[None, ...]

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        prepare_outputs = self.prepare_training_targets(batch_data_samples)
        (batch_gt_instances, batch_pred_instances, batch_gt_instances_ignore,
         batch_img_metas) = prepare_outputs

        batch_diff_bboxes = torch.stack([
            pred_instances.diff_bboxes_abs
            for pred_instances in batch_pred_instances
        ])
        batch_time = torch.stack(
            [pred_instances.time for pred_instances in batch_pred_instances])

        pred_logits, pred_bboxes = self(x, batch_diff_bboxes, batch_time)

        output = {
            'pred_logits': pred_logits[-1],
            'pred_boxes': pred_bboxes[-1]
        }
        if self.deep_supervision:
            output['aux_outputs'] = [{
                'pred_logits': a,
                'pred_boxes': b
            } for a, b in zip(pred_logits[:-1], pred_bboxes[:-1])]

        losses = self.criterion(output, batch_gt_instances, batch_img_metas)
        return losses

    def prepare_training_targets(self, batch_data_samples):
        # hard-setting seed to keep results same (if necessary)
        # random.seed(0)
        # torch.manual_seed(0)
        # torch.cuda.manual_seed_all(0)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

        batch_gt_instances = []
        batch_pred_instances = []
        batch_gt_instances_ignore = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            img_meta = data_sample.metainfo
            gt_instances = data_sample.gt_instances

            gt_bboxes = gt_instances.bboxes
            h, w = img_meta['img_shape']
            image_size = gt_bboxes.new_tensor([w, h, w, h])

            norm_gt_bboxes = gt_bboxes / image_size
            norm_gt_bboxes_cxcywh = bbox_xyxy_to_cxcywh(norm_gt_bboxes)
            pred_instances = self.prepare_diffusion(norm_gt_bboxes_cxcywh,
                                                    image_size)

            gt_instances.set_metainfo(dict(image_size=image_size))
            gt_instances.norm_bboxes_cxcywh = norm_gt_bboxes_cxcywh

            batch_gt_instances.append(gt_instances)
            batch_pred_instances.append(pred_instances)
            batch_img_metas.append(data_sample.metainfo)
            if 'ignored_instances' in data_sample:
                batch_gt_instances_ignore.append(data_sample.ignored_instances)
            else:
                batch_gt_instances_ignore.append(None)
        return (batch_gt_instances, batch_pred_instances,
                batch_gt_instances_ignore, batch_img_metas)

    def prepare_diffusion(self, gt_boxes, image_size):
        device = gt_boxes.device
        time = torch.randint(
            0, self.timesteps, (1, ), dtype=torch.long, device=device)
        noise = torch.randn(self.num_proposals, 4, device=device)

        num_gt = gt_boxes.shape[0]
        if num_gt < self.num_proposals:
            # 3 * sigma = 1/2 --> sigma: 1/6
            box_placeholder = torch.randn(
                self.num_proposals - num_gt, 4, device=device) / 6. + 0.5
            box_placeholder[:, 2:] = torch.clip(
                box_placeholder[:, 2:], min=1e-4)
            x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
        else:
            select_mask = [True] * self.num_proposals + \
                          [False] * (num_gt - self.num_proposals)
            random.shuffle(select_mask)
            x_start = gt_boxes[select_mask]

        x_start = (x_start * 2. - 1.) * self.snr_scale

        # noise sample
        x = self.q_sample(x_start=x_start, time=time, noise=noise)

        x = torch.clamp(x, min=-1 * self.snr_scale, max=self.snr_scale)
        x = ((x / self.snr_scale) + 1) / 2.

        diff_bboxes = bbox_cxcywh_to_xyxy(x)
        # convert to abs bboxes
        diff_bboxes_abs = diff_bboxes * image_size

        metainfo = dict(time=time.squeeze(-1))
        pred_instances = InstanceData(metainfo=metainfo)
        pred_instances.diff_bboxes = diff_bboxes
        pred_instances.diff_bboxes_abs = diff_bboxes_abs
        pred_instances.noise = noise
        return pred_instances
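
    # Illustrative note (not in the original file): when an image has fewer
    # ground-truth boxes than `num_proposals`, the remainder is padded with
    # placeholder boxes drawn from N(0.5, (1/6)^2) in normalized cxcywh space;
    # with 3 * sigma = 1/2, roughly 99.7% of them land inside the unit image.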

    # forward diffusion
    def q_sample(self, x_start, time, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        x_start_shape = x_start.shape

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, time,
                                        x_start_shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, time, x_start_shape)

        return sqrt_alphas_cumprod_t * x_start + \
            sqrt_one_minus_alphas_cumprod_t * noise
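
    # Illustrative note: `q_sample` is the closed-form forward diffusion
    #     x_t = sqrt(alphas_cumprod[t]) * x_0
    #           + sqrt(1 - alphas_cumprod[t]) * noise,
    # so boxes stay nearly clean for small `time` and approach pure Gaussian
    # noise as `time` nears `timesteps - 1`.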

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        # hard-setting seed to keep results same (if necessary)
        # seed = 0
        # random.seed(seed)
        # torch.manual_seed(seed)
        # torch.cuda.manual_seed_all(seed)

        device = x[-1].device

        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        (time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
         batch_image_size) = self.prepare_testing_targets(
             batch_img_metas, device)

        predictions = self.predict_by_feat(
            x,
            time_pairs=time_pairs,
            batch_noise_bboxes=batch_noise_bboxes,
            batch_noise_bboxes_raw=batch_noise_bboxes_raw,
            batch_image_size=batch_image_size,
            device=device,
            batch_img_metas=batch_img_metas)
        return predictions

    def predict_by_feat(self,
                        x,
                        time_pairs,
                        batch_noise_bboxes,
                        batch_noise_bboxes_raw,
                        batch_image_size,
                        device,
                        batch_img_metas=None,
                        cfg=None,
                        rescale=True):
        batch_size = len(batch_img_metas)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        ensemble_score, ensemble_label, ensemble_coord = [], [], []
        for time, time_next in time_pairs:
            batch_time = torch.full((batch_size, ),
                                    time,
                                    device=device,
                                    dtype=torch.long)
            # self_condition = x_start if self.self_condition else None
            pred_logits, pred_bboxes = self(x, batch_noise_bboxes, batch_time)

            x_start = pred_bboxes[-1]

            x_start = x_start / batch_image_size[:, None, :]
            x_start = bbox_xyxy_to_cxcywh(x_start)
            x_start = (x_start * 2 - 1.) * self.snr_scale
            x_start = torch.clamp(
                x_start, min=-1 * self.snr_scale, max=self.snr_scale)
            pred_noise = self.predict_noise_from_start(batch_noise_bboxes_raw,
                                                       batch_time, x_start)
            pred_noise_list, x_start_list = [], []
            noise_bboxes_list, num_remain_list = [], []
            if self.box_renewal:  # filter
                score_thr = cfg.get('score_thr', 0)
                for img_id in range(batch_size):
                    score_per_image = pred_logits[-1][img_id]

                    score_per_image = torch.sigmoid(score_per_image)
                    value, _ = torch.max(score_per_image, -1, keepdim=False)
                    keep_idx = value > score_thr

                    num_remain_list.append(torch.sum(keep_idx))
                    pred_noise_list.append(pred_noise[img_id, keep_idx, :])
                    x_start_list.append(x_start[img_id, keep_idx, :])
                    noise_bboxes_list.append(batch_noise_bboxes[img_id,
                                                                keep_idx, :])
            if time_next < 0:
                # Not same as original DiffusionDet
                if self.use_ensemble and self.sampling_timesteps > 1:
                    box_pred_per_image, scores_per_image, labels_per_image = \
                        self.inference(
                            box_cls=pred_logits[-1],
                            box_pred=pred_bboxes[-1],
                            cfg=cfg,
                            device=device)
                    ensemble_score.append(scores_per_image)
                    ensemble_label.append(labels_per_image)
                    ensemble_coord.append(box_pred_per_image)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) *
                                              (1 - alpha_next) /
                                              (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma**2).sqrt()

            batch_noise_bboxes_list = []
            batch_noise_bboxes_raw_list = []
            for idx in range(batch_size):
                pred_noise = pred_noise_list[idx]
                x_start = x_start_list[idx]
                noise_bboxes = noise_bboxes_list[idx]
                num_remain = num_remain_list[idx]
                noise = torch.randn_like(noise_bboxes)

                noise_bboxes = x_start * alpha_next.sqrt() + \
                    c * pred_noise + sigma * noise

                if self.box_renewal:  # filter
                    # replenish with randn boxes
                    if num_remain < self.num_proposals:
                        noise_bboxes = torch.cat(
                            (noise_bboxes,
                             torch.randn(
                                 self.num_proposals - num_remain,
                                 4,
                                 device=device)),
                            dim=0)
                    else:
                        select_mask = [True] * self.num_proposals + \
                                      [False] * (num_remain -
                                                 self.num_proposals)
                        random.shuffle(select_mask)
                        noise_bboxes = noise_bboxes[select_mask]

                # raw noise boxes
                batch_noise_bboxes_raw_list.append(noise_bboxes)

                # resize to xyxy
                noise_bboxes = torch.clamp(
                    noise_bboxes,
                    min=-1 * self.snr_scale,
                    max=self.snr_scale)
                noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
                noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
                noise_bboxes = noise_bboxes * batch_image_size[idx]

                batch_noise_bboxes_list.append(noise_bboxes)
            batch_noise_bboxes = torch.stack(batch_noise_bboxes_list)
            batch_noise_bboxes_raw = torch.stack(batch_noise_bboxes_raw_list)
            if self.use_ensemble and self.sampling_timesteps > 1:
                box_pred_per_image, scores_per_image, labels_per_image = \
                    self.inference(
                        box_cls=pred_logits[-1],
                        box_pred=pred_bboxes[-1],
                        cfg=cfg,
                        device=device)
                ensemble_score.append(scores_per_image)
                ensemble_label.append(labels_per_image)
                ensemble_coord.append(box_pred_per_image)

        if self.use_ensemble and self.sampling_timesteps > 1:
            steps = len(ensemble_score)
            results_list = []
            for idx in range(batch_size):
                ensemble_score_per_img = [
                    ensemble_score[i][idx] for i in range(steps)
                ]
                ensemble_label_per_img = [
                    ensemble_label[i][idx] for i in range(steps)
                ]
                ensemble_coord_per_img = [
                    ensemble_coord[i][idx] for i in range(steps)
                ]

                scores_per_image = torch.cat(ensemble_score_per_img, dim=0)
                labels_per_image = torch.cat(ensemble_label_per_img, dim=0)
                box_pred_per_image = torch.cat(ensemble_coord_per_img, dim=0)

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image, labels_per_image,
                        cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    scores_per_image = det_bboxes[:, -1]

                results = InstanceData()
                results.bboxes = box_pred_per_image
                results.scores = scores_per_image
                results.labels = labels_per_image
                results_list.append(results)
        else:
            box_cls = pred_logits[-1]
            box_pred = pred_bboxes[-1]

            results_list = self.inference(box_cls, box_pred, cfg, device)
        if rescale:
            results_list = self.do_results_post_process(
                results_list, cfg, batch_img_metas=batch_img_metas)
        return results_list

    @staticmethod
    def do_results_post_process(results_list, cfg, batch_img_metas=None):
        processed_results = []
        for results, img_meta in zip(results_list, batch_img_metas):
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            results.bboxes = scale_boxes(results.bboxes, scale_factor)
            # clip w, h
            h, w = img_meta['ori_shape']
            results.bboxes[:, 0::2] = results.bboxes[:, 0::2].clamp(
                min=0, max=w)
            results.bboxes[:, 1::2] = results.bboxes[:, 1::2].clamp(
                min=0, max=h)

            # filter small size bboxes
            if cfg.get('min_bbox_size', 0) >= 0:
                w, h = get_box_wh(results.bboxes)
                valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
                if not valid_mask.all():
                    results = results[valid_mask]
            processed_results.append(results)

        return processed_results

    def prepare_testing_targets(self, batch_img_metas, device):
        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == timesteps
        times = torch.linspace(
            -1, self.timesteps - 1, steps=self.sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        time_pairs = list(zip(times[:-1], times[1:]))

        noise_bboxes_list = []
        noise_bboxes_raw_list = []
        image_size_list = []
        for img_meta in batch_img_metas:
            h, w = img_meta['img_shape']
            image_size = torch.tensor([w, h, w, h],
                                      dtype=torch.float32,
                                      device=device)
            noise_bboxes_raw = torch.randn((self.num_proposals, 4),
                                           device=device)
            noise_bboxes = torch.clamp(
                noise_bboxes_raw, min=-1 * self.snr_scale, max=self.snr_scale)
            noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
            noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
            noise_bboxes = noise_bboxes * image_size

            noise_bboxes_raw_list.append(noise_bboxes_raw)
            noise_bboxes_list.append(noise_bboxes)
            image_size_list.append(image_size[None])

        batch_noise_bboxes = torch.stack(noise_bboxes_list)
        batch_image_size = torch.cat(image_size_list)
        batch_noise_bboxes_raw = torch.stack(noise_bboxes_raw_list)
        return (time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
                batch_image_size)
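
    # Illustrative sketch: with timesteps=1000 and sampling_timesteps=4,
    # torch.linspace(-1, 999, steps=5) gives [-1, 249, 499, 749, 999], so the
    # DDIM loop visits the pairs (999, 749), (749, 499), (499, 249), (249, -1).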

    def predict_noise_from_start(self, x_t, t, x0):
        results = (extract(
            self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return results
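
    # Illustrative note: this inverts the forward process
    #     x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    # for eps, rewritten as
    #     eps = (sqrt(1 / a_bar_t) * x_t - x_0) / sqrt(1 / a_bar_t - 1),
    # which is exactly what the two registered buffers encode.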

    def inference(self, box_cls, box_pred, cfg, device):
        """
        Args:
            box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
                The tensor predicts the classification probability for
                each proposal.
            box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
                The tensor predicts 4-vector (x, y, w, h) box regression
                values for every proposal.

        Returns:
            results (List[Instances]): a list of #images elements.
        """
        results = []

        if self.use_focal_loss or self.use_fed_loss:
            scores = torch.sigmoid(box_cls)
            labels = torch.arange(
                self.num_classes,
                device=device).unsqueeze(0).repeat(self.num_proposals,
                                                   1).flatten(0, 1)
            box_pred_list = []
            scores_list = []
            labels_list = []
            for i, (scores_per_image,
                    box_pred_per_image) in enumerate(zip(scores, box_pred)):
                scores_per_image, topk_indices = scores_per_image.flatten(
                    0, 1).topk(
                        self.num_proposals, sorted=False)
                labels_per_image = labels[topk_indices]
                box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(
                    1, self.num_classes, 1).view(-1, 4)
                box_pred_per_image = box_pred_per_image[topk_indices]
                if self.use_ensemble and self.sampling_timesteps > 1:
                    box_pred_list.append(box_pred_per_image)
                    scores_list.append(scores_per_image)
                    labels_list.append(labels_per_image)
                    continue

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image, labels_per_image,
                        cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    # some nms would reweight the score, such as softnms
                    scores_per_image = det_bboxes[:, -1]
                result = InstanceData()
                result.bboxes = box_pred_per_image
                result.scores = scores_per_image
                result.labels = labels_per_image
                results.append(result)
        else:
            # For each box we assign the best class or the second
            # best if the best one is `no_object`.
            scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
            for i, (scores_per_image, labels_per_image,
                    box_pred_per_image) in enumerate(
                        zip(scores, labels, box_pred)):
                if self.use_ensemble and self.sampling_timesteps > 1:
                    return box_pred_per_image, scores_per_image, \
                        labels_per_image

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image, labels_per_image,
                        cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    # some nms would reweight the score, such as softnms
                    scores_per_image = det_bboxes[:, -1]
                result = InstanceData()
                result.bboxes = box_pred_per_image
                result.scores = scores_per_image
                result.labels = labels_per_image
                results.append(result)
        if self.use_ensemble and self.sampling_timesteps > 1:
            return box_pred_list, scores_list, labels_list
        else:
            return results


@MODELS.register_module()
class SingleDiffusionDetHead(nn.Module):

    def __init__(
            self,
            num_classes=80,
            feat_channels=256,
            dim_feedforward=2048,
            num_cls_convs=1,
            num_reg_convs=3,
            num_heads=8,
            dropout=0.0,
            pooler_resolution=7,
            scale_clamp=_DEFAULT_SCALE_CLAMP,
            bbox_weights=(2.0, 2.0, 1.0, 1.0),
            use_focal_loss=True,
            use_fed_loss=False,
            act_cfg=dict(type='ReLU', inplace=True),
            dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)
    ) -> None:
        super().__init__()
        self.feat_channels = feat_channels

        # Dynamic
        self.self_attn = nn.MultiheadAttention(
            feat_channels, num_heads, dropout=dropout)
        self.inst_interact = DynamicConv(
            feat_channels=feat_channels,
            pooler_resolution=pooler_resolution,
            dynamic_dim=dynamic_conv['dynamic_dim'],
            dynamic_num=dynamic_conv['dynamic_num'])

        self.linear1 = nn.Linear(feat_channels, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, feat_channels)

        self.norm1 = nn.LayerNorm(feat_channels)
        self.norm2 = nn.LayerNorm(feat_channels)
        self.norm3 = nn.LayerNorm(feat_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = build_activation_layer(act_cfg)

        # block time mlp
        self.block_time_mlp = nn.Sequential(
            nn.SiLU(), nn.Linear(feat_channels * 4, feat_channels * 2))

        # cls.
        cls_module = list()
        for _ in range(num_cls_convs):
            cls_module.append(nn.Linear(feat_channels, feat_channels, False))
            cls_module.append(nn.LayerNorm(feat_channels))
            cls_module.append(nn.ReLU(inplace=True))
        self.cls_module = nn.ModuleList(cls_module)

        # reg.
        reg_module = list()
        for _ in range(num_reg_convs):
            reg_module.append(nn.Linear(feat_channels, feat_channels, False))
            reg_module.append(nn.LayerNorm(feat_channels))
            reg_module.append(nn.ReLU(inplace=True))
        self.reg_module = nn.ModuleList(reg_module)

        # pred.
        self.use_focal_loss = use_focal_loss
        self.use_fed_loss = use_fed_loss
        if self.use_focal_loss or self.use_fed_loss:
            self.class_logits = nn.Linear(feat_channels, num_classes)
        else:
            self.class_logits = nn.Linear(feat_channels, num_classes + 1)
        self.bboxes_delta = nn.Linear(feat_channels, 4)
        self.scale_clamp = scale_clamp
        self.bbox_weights = bbox_weights

    def forward(self, features, bboxes, pro_features, pooler, time_emb):
        """
        :param bboxes: (N, num_boxes, 4)
        :param pro_features: (N, num_boxes, feat_channels)
        """
        N, num_boxes = bboxes.shape[:2]

        # roi_feature.
        proposal_boxes = list()
        for b in range(N):
            proposal_boxes.append(bboxes[b])
        rois = bbox2roi(proposal_boxes)

        roi_features = pooler(features, rois)

        if pro_features is None:
            pro_features = roi_features.view(N, num_boxes, self.feat_channels,
                                             -1).mean(-1)

        roi_features = roi_features.view(N * num_boxes, self.feat_channels,
                                         -1).permute(2, 0, 1)

        # self_att.
        pro_features = pro_features.view(N, num_boxes,
                                         self.feat_channels).permute(1, 0, 2)
        pro_features2 = self.self_attn(
            pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.view(
            num_boxes, N,
            self.feat_channels).permute(1, 0,
                                        2).reshape(1, N * num_boxes,
                                                   self.feat_channels)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(
            self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(0, 1).reshape(N * num_boxes, -1)

        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_boxes, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))

        return (class_logits.view(N, num_boxes,
                                  -1), pred_bboxes.view(N, num_boxes,
                                                        -1), obj_features)

    def apply_deltas(self, deltas, boxes):
        """Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4),
                where k >= 1. deltas[i] represents k potentially
                different class-specific box transformations for
                the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        boxes = boxes.to(deltas.dtype)
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
        return pred_boxes
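

# Illustrative sketch: a zero delta leaves a box unchanged, since dx = dy = 0
# keeps the centre and exp(0) = 1 keeps the size.
#
#   >>> head = SingleDiffusionDetHead()
#   >>> boxes = torch.tensor([[10., 10., 50., 30.]])
#   >>> head.apply_deltas(torch.zeros(1, 4), boxes)
#   tensor([[10., 10., 50., 30.]])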


class DynamicConv(nn.Module):

    def __init__(self,
                 feat_channels: int,
                 dynamic_dim: int = 64,
                 dynamic_num: int = 2,
                 pooler_resolution: int = 7) -> None:
        super().__init__()
        self.feat_channels = feat_channels
        self.dynamic_dim = dynamic_dim
        self.dynamic_num = dynamic_num
        self.num_params = self.feat_channels * self.dynamic_dim
        self.dynamic_layer = nn.Linear(self.feat_channels,
                                       self.dynamic_num * self.num_params)

        self.norm1 = nn.LayerNorm(self.dynamic_dim)
        self.norm2 = nn.LayerNorm(self.feat_channels)

        self.activation = nn.ReLU(inplace=True)

        num_output = self.feat_channels * pooler_resolution**2
        self.out_layer = nn.Linear(num_output, self.feat_channels)
        self.norm3 = nn.LayerNorm(self.feat_channels)

    def forward(self, pro_features: Tensor, roi_features: Tensor) -> Tensor:
        """Forward function.

        Args:
            pro_features: (1, N * num_boxes, self.feat_channels)
            roi_features: (49, N * num_boxes, self.feat_channels)

        Returns:
            Tensor: fused features of shape
                (N * num_boxes, self.feat_channels).
        """
        features = roi_features.permute(1, 0, 2)
        parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)

        param1 = parameters[:, :, :self.num_params].view(
            -1, self.feat_channels, self.dynamic_dim)
        param2 = parameters[:, :,
                            self.num_params:].view(-1, self.dynamic_dim,
                                                   self.feat_channels)

        features = torch.bmm(features, param1)
        features = self.norm1(features)
        features = self.activation(features)

        features = torch.bmm(features, param2)
        features = self.norm2(features)
        features = self.activation(features)

        features = features.flatten(1)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)
        return features
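

# Illustrative sketch: DynamicConv fuses each (proposal feature, RoI feature)
# pair into a single vector per box.
#
#   >>> conv = DynamicConv(feat_channels=256)
#   >>> pro = torch.randn(1, 2 * 500, 256)   # (1, N * num_boxes, C)
#   >>> roi = torch.randn(49, 2 * 500, 256)  # (7 * 7, N * num_boxes, C)
#   >>> conv(pro, roi).shape
#   torch.Size([1000, 256])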