# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/detector.py # noqa
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/head.py # noqa
# This work is licensed under the CC-BY-NC 4.0 License.
# Users should be careful about adopting these features in any commercial matters. # noqa
# For more details, please refer to https://github.com/ShoufaChen/DiffusionDet/blob/main/LICENSE # noqa
import copy
import math
import random
import warnings
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.ops import batched_nms
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import (bbox2roi, bbox_cxcywh_to_xyxy,
                                   bbox_xyxy_to_cxcywh, get_box_wh,
                                   scale_boxes)
from mmdet.utils import InstanceList

_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(
        ((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
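
# Illustrative check (added comment, not from the original file): the schedule
# returns one beta per timestep and the betas grow towards the 0.999 clip, e.g.
#   >>> betas = cosine_beta_schedule(1000)
#   >>> betas.shape
#   torch.Size([1000])
#   >>> bool(betas[0] < betas[-1])
#   True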


def extract(a, t, x_shape):
    """extract the appropriate t index for a batch of indices."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1, ) * (len(x_shape) - 1)))
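
# Illustrative example (assumed shapes, added for clarity): ``extract`` gathers
# one coefficient per sample and reshapes it so it broadcasts over the
# remaining dimensions of ``x_shape``, e.g.
#   >>> a = torch.linspace(0, 1, 1000)   # per-timestep coefficients
#   >>> t = torch.tensor([10, 500])      # one timestep per image
#   >>> extract(a, t, (2, 300, 4)).shape
#   torch.Size([2, 1, 1])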


class SinusoidalPositionEmbeddings(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(
            torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
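
# Illustrative shape check (assumed values, added for clarity): the module maps
# a batch of scalar timesteps to ``dim``-dimensional embeddings, e.g.
#   >>> emb = SinusoidalPositionEmbeddings(256)
#   >>> emb(torch.tensor([0., 999.])).shape
#   torch.Size([2, 256])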


@MODELS.register_module()
class DynamicDiffusionDetHead(nn.Module):

    def __init__(self,
                 num_classes=80,
                 feat_channels=256,
                 num_proposals=500,
                 num_heads=6,
                 prior_prob=0.01,
                 snr_scale=2.0,
                 timesteps=1000,
                 sampling_timesteps=1,
                 self_condition=False,
                 box_renewal=True,
                 use_ensemble=True,
                 deep_supervision=True,
                 ddim_sampling_eta=1.0,
                 criterion=dict(
                     type='DiffusionDetCriterion',
                     num_classes=80,
                     assigner=dict(
                         type='DiffusionDetMatcher',
                         match_costs=[
                             dict(
                                 type='FocalLossCost',
                                 alpha=0.25,
                                 gamma=2.0,
                                 weight=2.0),
                             dict(
                                 type='BBoxL1Cost',
                                 weight=5.0,
                                 box_format='xyxy'),
                             dict(type='IoUCost', iou_mode='giou', weight=2.0)
                         ],
                         center_radius=2.5,
                         candidate_topk=5),
                 ),
                 single_head=dict(
                     type='DiffusionDetHead',
                     num_cls_convs=1,
                     num_reg_convs=3,
                     dim_feedforward=2048,
                     num_heads=8,
                     dropout=0.0,
                     act_cfg=dict(type='ReLU'),
                     dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)),
                 roi_extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(
                         type='RoIAlign', output_size=7, sampling_ratio=2),
                     out_channels=256,
                     featmap_strides=[4, 8, 16, 32]),
                 test_cfg=None,
                 **kwargs) -> None:
        super().__init__()
        self.roi_extractor = MODELS.build(roi_extractor)

        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.num_proposals = num_proposals
        self.num_heads = num_heads

        # Build Diffusion
        assert isinstance(timesteps, int), 'The type of `timesteps` should ' \
            f'be int but got {type(timesteps)}'
        assert sampling_timesteps <= timesteps
        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.snr_scale = snr_scale

        self.ddim_sampling = self.sampling_timesteps < self.timesteps
        self.ddim_sampling_eta = ddim_sampling_eta
        self.self_condition = self_condition
        self.box_renewal = box_renewal
        self.use_ensemble = use_ensemble

        self._build_diffusion()

        # Build assigner
        assert criterion.get('assigner', None) is not None
        assigner = TASK_UTILS.build(criterion.get('assigner'))
        # Init parameters.
        self.use_focal_loss = assigner.use_focal_loss
        self.use_fed_loss = assigner.use_fed_loss

        # build criterion
        criterion.update(deep_supervision=deep_supervision)
        self.criterion = TASK_UTILS.build(criterion)

        # Build Dynamic Head.
        single_head_ = single_head.copy()
        single_head_num_classes = single_head_.get('num_classes', None)
        if single_head_num_classes is None:
            single_head_.update(num_classes=num_classes)
        else:
            if single_head_num_classes != num_classes:
                warnings.warn(
                    'The `num_classes` of `DynamicDiffusionDetHead` and '
                    '`SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.num_classes` to {num_classes}')
                single_head_.update(num_classes=num_classes)
        single_head_feat_channels = single_head_.get('feat_channels', None)
        if single_head_feat_channels is None:
            single_head_.update(feat_channels=feat_channels)
        else:
            if single_head_feat_channels != feat_channels:
                warnings.warn(
                    'The `feat_channels` of `DynamicDiffusionDetHead` and '
                    '`SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.feat_channels` to {feat_channels}')
                single_head_.update(feat_channels=feat_channels)

        default_pooler_resolution = roi_extractor['roi_layer'].get(
            'output_size')
        assert default_pooler_resolution is not None
        single_head_pooler_resolution = single_head_.get('pooler_resolution')
        if single_head_pooler_resolution is None:
            single_head_.update(pooler_resolution=default_pooler_resolution)
        else:
            if single_head_pooler_resolution != default_pooler_resolution:
                warnings.warn(
                    'The `pooler_resolution` of `DynamicDiffusionDetHead` '
                    'and `SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.pooler_resolution` to '
                    f'{default_pooler_resolution}')
                single_head_.update(
                    pooler_resolution=default_pooler_resolution)

        single_head_.update(
            use_focal_loss=self.use_focal_loss,
            use_fed_loss=self.use_fed_loss)
        single_head_module = MODELS.build(single_head_)

        self.num_heads = num_heads
        self.head_series = nn.ModuleList(
            [copy.deepcopy(single_head_module) for _ in range(num_heads)])

        self.deep_supervision = deep_supervision

        # Gaussian random feature embedding layer for time
        time_dim = feat_channels * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(feat_channels),
            nn.Linear(feat_channels, time_dim), nn.GELU(),
            nn.Linear(time_dim, time_dim))

        self.prior_prob = prior_prob
        self.test_cfg = test_cfg
        self.use_nms = self.test_cfg.get('use_nms', True)
        self._init_weights()

    def _init_weights(self):
        # init all parameters.
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

            # initialize the bias for focal loss and fed loss.
            if self.use_focal_loss or self.use_fed_loss:
                if p.shape[-1] == self.num_classes or \
                        p.shape[-1] == self.num_classes + 1:
                    nn.init.constant_(p, bias_value)

    def _build_diffusion(self):
        betas = cosine_beta_schedule(self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))
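
    # Note (added for clarity, not in the original file): the buffers above are
    # the standard DDPM quantities for
    #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps.
    # ``sqrt_alphas_cumprod`` / ``sqrt_one_minus_alphas_cumprod`` are consumed
    # by ``q_sample`` during training, while the ``sqrt_recip*`` pair inverts
    # the same relation in ``predict_noise_from_start`` during DDIM sampling.
    # The ``posterior_*`` buffers are registered but not used elsewhere in this
    # file.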

    def forward(self, features, init_bboxes, init_t, init_features=None):
        time = self.time_mlp(init_t)

        inter_class_logits = []
        inter_pred_bboxes = []

        bs = len(features[0])
        bboxes = init_bboxes

        if init_features is not None:
            init_features = init_features[None].repeat(1, bs, 1)
            proposal_features = init_features.clone()
        else:
            proposal_features = None

        for head_idx, single_head in enumerate(self.head_series):
            class_logits, pred_bboxes, proposal_features = single_head(
                features, bboxes, proposal_features, self.roi_extractor, time)
            if self.deep_supervision:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            bboxes = pred_bboxes.detach()

        if self.deep_supervision:
            return torch.stack(inter_class_logits), torch.stack(
                inter_pred_bboxes)
        else:
            return class_logits[None, ...], pred_bboxes[None, ...]
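
    # Note on shapes (added comment, derived from the code above): with
    # ``deep_supervision=True`` the returned tensors are stacked over the
    # ``num_heads`` cascade stages, i.e. class logits of shape
    # (num_heads, bs, num_proposals, num_classes) and boxes of shape
    # (num_heads, bs, num_proposals, 4); otherwise only the last stage is
    # returned with a leading singleton dimension.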

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        prepare_outputs = self.prepare_training_targets(batch_data_samples)
        (batch_gt_instances, batch_pred_instances, batch_gt_instances_ignore,
         batch_img_metas) = prepare_outputs

        batch_diff_bboxes = torch.stack([
            pred_instances.diff_bboxes_abs
            for pred_instances in batch_pred_instances
        ])
        batch_time = torch.stack(
            [pred_instances.time for pred_instances in batch_pred_instances])

        pred_logits, pred_bboxes = self(x, batch_diff_bboxes, batch_time)

        output = {
            'pred_logits': pred_logits[-1],
            'pred_boxes': pred_bboxes[-1]
        }
        if self.deep_supervision:
            output['aux_outputs'] = [{
                'pred_logits': a,
                'pred_boxes': b
            } for a, b in zip(pred_logits[:-1], pred_bboxes[:-1])]

        losses = self.criterion(output, batch_gt_instances, batch_img_metas)
        return losses

    def prepare_training_targets(self, batch_data_samples):
        # hard-setting seed to keep results same (if necessary)
        # random.seed(0)
        # torch.manual_seed(0)
        # torch.cuda.manual_seed_all(0)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

        batch_gt_instances = []
        batch_pred_instances = []
        batch_gt_instances_ignore = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            img_meta = data_sample.metainfo
            gt_instances = data_sample.gt_instances

            gt_bboxes = gt_instances.bboxes
            h, w = img_meta['img_shape']
            image_size = gt_bboxes.new_tensor([w, h, w, h])

            norm_gt_bboxes = gt_bboxes / image_size
            norm_gt_bboxes_cxcywh = bbox_xyxy_to_cxcywh(norm_gt_bboxes)
            pred_instances = self.prepare_diffusion(norm_gt_bboxes_cxcywh,
                                                    image_size)

            gt_instances.set_metainfo(dict(image_size=image_size))
            gt_instances.norm_bboxes_cxcywh = norm_gt_bboxes_cxcywh

            batch_gt_instances.append(gt_instances)
            batch_pred_instances.append(pred_instances)
            batch_img_metas.append(data_sample.metainfo)

            if 'ignored_instances' in data_sample:
                batch_gt_instances_ignore.append(
                    data_sample.ignored_instances)
            else:
                batch_gt_instances_ignore.append(None)
        return (batch_gt_instances, batch_pred_instances,
                batch_gt_instances_ignore, batch_img_metas)

    def prepare_diffusion(self, gt_boxes, image_size):
        device = gt_boxes.device
        time = torch.randint(
            0, self.timesteps, (1, ), dtype=torch.long, device=device)
        noise = torch.randn(self.num_proposals, 4, device=device)

        num_gt = gt_boxes.shape[0]
        if num_gt < self.num_proposals:
            # 3 * sigma = 1/2 --> sigma: 1/6
            box_placeholder = torch.randn(
                self.num_proposals - num_gt, 4, device=device) / 6. + 0.5
            box_placeholder[:, 2:] = torch.clip(
                box_placeholder[:, 2:], min=1e-4)
            x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
        else:
            select_mask = [True] * self.num_proposals + \
                [False] * (num_gt - self.num_proposals)
            random.shuffle(select_mask)
            x_start = gt_boxes[select_mask]

        x_start = (x_start * 2. - 1.) * self.snr_scale

        # noise sample
        x = self.q_sample(x_start=x_start, time=time, noise=noise)

        x = torch.clamp(x, min=-1 * self.snr_scale, max=self.snr_scale)
        x = ((x / self.snr_scale) + 1) / 2.

        diff_bboxes = bbox_cxcywh_to_xyxy(x)
        # convert to abs bboxes
        diff_bboxes_abs = diff_bboxes * image_size

        metainfo = dict(time=time.squeeze(-1))
        pred_instances = InstanceData(metainfo=metainfo)
        pred_instances.diff_bboxes = diff_bboxes
        pred_instances.diff_bboxes_abs = diff_bboxes_abs
        pred_instances.noise = noise
        return pred_instances
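
    # Note (added for clarity): ``prepare_diffusion`` pads (or subsamples) the
    # normalized cxcywh ground-truth boxes to exactly ``num_proposals`` boxes,
    # rescales them from [0, 1] to [-snr_scale, snr_scale] as the diffusion
    # signal, noises them with ``q_sample`` at a random timestep, and finally
    # maps them back to absolute xyxy boxes that the heads consume.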

    # forward diffusion
    def q_sample(self, x_start, time, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        x_start_shape = x_start.shape

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, time,
                                        x_start_shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, time, x_start_shape)

        return sqrt_alphas_cumprod_t * x_start + \
            sqrt_one_minus_alphas_cumprod_t * noise
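
    # In formula form (added comment): ``q_sample`` draws from
    #   q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0,
    #                    (1 - alphas_cumprod[t]) * I),
    # i.e. x_t = sqrt_alphas_cumprod[t] * x_0
    #            + sqrt_one_minus_alphas_cumprod[t] * noise.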

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        # hard-setting seed to keep results same (if necessary)
        # seed = 0
        # random.seed(seed)
        # torch.manual_seed(seed)
        # torch.cuda.manual_seed_all(seed)

        device = x[-1].device

        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        (time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
         batch_image_size) = self.prepare_testing_targets(
             batch_img_metas, device)

        predictions = self.predict_by_feat(
            x,
            time_pairs=time_pairs,
            batch_noise_bboxes=batch_noise_bboxes,
            batch_noise_bboxes_raw=batch_noise_bboxes_raw,
            batch_image_size=batch_image_size,
            device=device,
            batch_img_metas=batch_img_metas)
        return predictions

    def predict_by_feat(self,
                        x,
                        time_pairs,
                        batch_noise_bboxes,
                        batch_noise_bboxes_raw,
                        batch_image_size,
                        device,
                        batch_img_metas=None,
                        cfg=None,
                        rescale=True):
        batch_size = len(batch_img_metas)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        ensemble_score, ensemble_label, ensemble_coord = [], [], []
        for time, time_next in time_pairs:
            batch_time = torch.full((batch_size, ),
                                    time,
                                    device=device,
                                    dtype=torch.long)
            # self_condition = x_start if self.self_condition else None
            pred_logits, pred_bboxes = self(x, batch_noise_bboxes, batch_time)

            x_start = pred_bboxes[-1]

            x_start = x_start / batch_image_size[:, None, :]
            x_start = bbox_xyxy_to_cxcywh(x_start)
            x_start = (x_start * 2 - 1.) * self.snr_scale
            x_start = torch.clamp(
                x_start, min=-1 * self.snr_scale, max=self.snr_scale)
            pred_noise = self.predict_noise_from_start(
                batch_noise_bboxes_raw, batch_time, x_start)
            pred_noise_list, x_start_list = [], []
            noise_bboxes_list, num_remain_list = [], []
            if self.box_renewal:  # filter
                score_thr = cfg.get('score_thr', 0)
                for img_id in range(batch_size):
                    score_per_image = pred_logits[-1][img_id]

                    score_per_image = torch.sigmoid(score_per_image)
                    value, _ = torch.max(score_per_image, -1, keepdim=False)
                    keep_idx = value > score_thr

                    num_remain_list.append(torch.sum(keep_idx))
                    pred_noise_list.append(pred_noise[img_id, keep_idx, :])
                    x_start_list.append(x_start[img_id, keep_idx, :])
                    noise_bboxes_list.append(
                        batch_noise_bboxes[img_id, keep_idx, :])
            if time_next < 0:
                # Not same as original DiffusionDet
                if self.use_ensemble and self.sampling_timesteps > 1:
                    box_pred_per_image, scores_per_image, labels_per_image = \
                        self.inference(
                            box_cls=pred_logits[-1],
                            box_pred=pred_bboxes[-1],
                            cfg=cfg,
                            device=device)
                    ensemble_score.append(scores_per_image)
                    ensemble_label.append(labels_per_image)
                    ensemble_coord.append(box_pred_per_image)

                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) *
                                              (1 - alpha_next) /
                                              (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma**2).sqrt()
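
            # DDIM update (added comment): the renewed boxes below follow
            #   x_{t_next} = sqrt(alpha_next) * x0_hat
            #                + sqrt(1 - alpha_next - sigma^2) * eps_hat
            #                + sigma * z,
            # where ``ddim_sampling_eta`` interpolates between deterministic
            # DDIM sampling (eta = 0) and ancestral DDPM-style sampling
            # (eta = 1).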
            batch_noise_bboxes_list = []
            batch_noise_bboxes_raw_list = []
            for idx in range(batch_size):
                pred_noise = pred_noise_list[idx]
                x_start = x_start_list[idx]
                noise_bboxes = noise_bboxes_list[idx]
                num_remain = num_remain_list[idx]
                noise = torch.randn_like(noise_bboxes)

                noise_bboxes = x_start * alpha_next.sqrt() + \
                    c * pred_noise + sigma * noise

                if self.box_renewal:  # filter
                    # replenish with randn boxes
                    if num_remain < self.num_proposals:
                        noise_bboxes = torch.cat(
                            (noise_bboxes,
                             torch.randn(
                                 self.num_proposals - num_remain,
                                 4,
                                 device=device)),
                            dim=0)
                    else:
                        select_mask = [True] * self.num_proposals + \
                            [False] * (num_remain - self.num_proposals)
                        random.shuffle(select_mask)
                        noise_bboxes = noise_bboxes[select_mask]

                # raw noise boxes
                batch_noise_bboxes_raw_list.append(noise_bboxes)
                # resize to xyxy
                noise_bboxes = torch.clamp(
                    noise_bboxes,
                    min=-1 * self.snr_scale,
                    max=self.snr_scale)
                noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
                noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
                noise_bboxes = noise_bboxes * batch_image_size[idx]
                batch_noise_bboxes_list.append(noise_bboxes)
            batch_noise_bboxes = torch.stack(batch_noise_bboxes_list)
            batch_noise_bboxes_raw = torch.stack(batch_noise_bboxes_raw_list)

            if self.use_ensemble and self.sampling_timesteps > 1:
                box_pred_per_image, scores_per_image, labels_per_image = \
                    self.inference(
                        box_cls=pred_logits[-1],
                        box_pred=pred_bboxes[-1],
                        cfg=cfg,
                        device=device)
                ensemble_score.append(scores_per_image)
                ensemble_label.append(labels_per_image)
                ensemble_coord.append(box_pred_per_image)

        if self.use_ensemble and self.sampling_timesteps > 1:
            steps = len(ensemble_score)
            results_list = []
            for idx in range(batch_size):
                ensemble_score_per_img = [
                    ensemble_score[i][idx] for i in range(steps)
                ]
                ensemble_label_per_img = [
                    ensemble_label[i][idx] for i in range(steps)
                ]
                ensemble_coord_per_img = [
                    ensemble_coord[i][idx] for i in range(steps)
                ]

                scores_per_image = torch.cat(ensemble_score_per_img, dim=0)
                labels_per_image = torch.cat(ensemble_label_per_img, dim=0)
                box_pred_per_image = torch.cat(ensemble_coord_per_img, dim=0)

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image,
                        labels_per_image, cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    scores_per_image = det_bboxes[:, -1]

                results = InstanceData()
                results.bboxes = box_pred_per_image
                results.scores = scores_per_image
                results.labels = labels_per_image
                results_list.append(results)
        else:
            box_cls = pred_logits[-1]
            box_pred = pred_bboxes[-1]
            results_list = self.inference(box_cls, box_pred, cfg, device)
        if rescale:
            results_list = self.do_results_post_process(
                results_list, cfg, batch_img_metas=batch_img_metas)
        return results_list

    @staticmethod
    def do_results_post_process(results_list, cfg, batch_img_metas=None):
        processed_results = []
        for results, img_meta in zip(results_list, batch_img_metas):
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            results.bboxes = scale_boxes(results.bboxes, scale_factor)
            # clip w, h
            h, w = img_meta['ori_shape']
            results.bboxes[:, 0::2] = results.bboxes[:, 0::2].clamp(
                min=0, max=w)
            results.bboxes[:, 1::2] = results.bboxes[:, 1::2].clamp(
                min=0, max=h)

            # filter small size bboxes
            if cfg.get('min_bbox_size', 0) >= 0:
                w, h = get_box_wh(results.bboxes)
                valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
                if not valid_mask.all():
                    results = results[valid_mask]
            processed_results.append(results)
        return processed_results

    def prepare_testing_targets(self, batch_img_metas, device):
        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == timesteps
        times = torch.linspace(
            -1, self.timesteps - 1, steps=self.sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        time_pairs = list(zip(times[:-1], times[1:]))

        noise_bboxes_list = []
        noise_bboxes_raw_list = []
        image_size_list = []
        for img_meta in batch_img_metas:
            h, w = img_meta['img_shape']
            image_size = torch.tensor([w, h, w, h],
                                      dtype=torch.float32,
                                      device=device)
            noise_bboxes_raw = torch.randn((self.num_proposals, 4),
                                           device=device)
            noise_bboxes = torch.clamp(
                noise_bboxes_raw, min=-1 * self.snr_scale, max=self.snr_scale)
            noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
            noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
            noise_bboxes = noise_bboxes * image_size

            noise_bboxes_raw_list.append(noise_bboxes_raw)
            noise_bboxes_list.append(noise_bboxes)
            image_size_list.append(image_size[None])

        batch_noise_bboxes = torch.stack(noise_bboxes_list)
        batch_image_size = torch.cat(image_size_list)
        batch_noise_bboxes_raw = torch.stack(noise_bboxes_raw_list)
        return (time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
                batch_image_size)
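
    # Illustrative example (assumed defaults, added for clarity): with
    # ``timesteps=1000`` and ``sampling_timesteps=1``, ``times`` becomes
    # [999, -1] and ``time_pairs`` is [(999, -1)], i.e. a single DDIM step
    # that maps the fully noised boxes directly to the final prediction.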

    def predict_noise_from_start(self, x_t, t, x0):
        results = (extract(
            self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return results
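
    # In formula form (added comment): this inverts the forward process, i.e.
    #   eps_hat = (x_t / sqrt(alphas_cumprod[t]) - x0_hat)
    #             / sqrt(1 / alphas_cumprod[t] - 1).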

    def inference(self, box_cls, box_pred, cfg, device):
        """
        Args:
            box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
                The tensor predicts the classification probability for
                each proposal.
            box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
                The tensor predicts absolute (x1, y1, x2, y2) box regression
                values for every proposal.

        Returns:
            results (List[:obj:`InstanceData`]): a list of #images elements.
        """
        results = []

        if self.use_focal_loss or self.use_fed_loss:
            scores = torch.sigmoid(box_cls)
            labels = torch.arange(
                self.num_classes,
                device=device).unsqueeze(0).repeat(self.num_proposals,
                                                   1).flatten(0, 1)
            box_pred_list = []
            scores_list = []
            labels_list = []
            for i, (scores_per_image,
                    box_pred_per_image) in enumerate(zip(scores, box_pred)):
                scores_per_image, topk_indices = scores_per_image.flatten(
                    0, 1).topk(
                        self.num_proposals, sorted=False)
                labels_per_image = labels[topk_indices]
                box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(
                    1, self.num_classes, 1).view(-1, 4)
                box_pred_per_image = box_pred_per_image[topk_indices]
                if self.use_ensemble and self.sampling_timesteps > 1:
                    box_pred_list.append(box_pred_per_image)
                    scores_list.append(scores_per_image)
                    labels_list.append(labels_per_image)
                    continue

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image,
                        labels_per_image, cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    # some nms would reweight the score, such as softnms
                    scores_per_image = det_bboxes[:, -1]
                result = InstanceData()
                result.bboxes = box_pred_per_image
                result.scores = scores_per_image
                result.labels = labels_per_image
                results.append(result)
        else:
            # For each box we assign the best class or the second
            # best if the best one is `no_object`.
            scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)

            for i, (scores_per_image, labels_per_image,
                    box_pred_per_image) in enumerate(
                        zip(scores, labels, box_pred)):
                if self.use_ensemble and self.sampling_timesteps > 1:
                    return box_pred_per_image, scores_per_image, \
                        labels_per_image

                if self.use_nms:
                    det_bboxes, keep_idxs = batched_nms(
                        box_pred_per_image, scores_per_image,
                        labels_per_image, cfg.nms)
                    box_pred_per_image = box_pred_per_image[keep_idxs]
                    labels_per_image = labels_per_image[keep_idxs]
                    # some nms would reweight the score, such as softnms
                    scores_per_image = det_bboxes[:, -1]
                result = InstanceData()
                result.bboxes = box_pred_per_image
                result.scores = scores_per_image
                result.labels = labels_per_image
                results.append(result)

        if self.use_ensemble and self.sampling_timesteps > 1:
            return box_pred_list, scores_list, labels_list
        else:
            return results


@MODELS.register_module()
class SingleDiffusionDetHead(nn.Module):

    def __init__(
        self,
        num_classes=80,
        feat_channels=256,
        dim_feedforward=2048,
        num_cls_convs=1,
        num_reg_convs=3,
        num_heads=8,
        dropout=0.0,
        pooler_resolution=7,
        scale_clamp=_DEFAULT_SCALE_CLAMP,
        bbox_weights=(2.0, 2.0, 1.0, 1.0),
        use_focal_loss=True,
        use_fed_loss=False,
        act_cfg=dict(type='ReLU', inplace=True),
        dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)
    ) -> None:
        super().__init__()
        self.feat_channels = feat_channels

        # Dynamic
        self.self_attn = nn.MultiheadAttention(
            feat_channels, num_heads, dropout=dropout)
        self.inst_interact = DynamicConv(
            feat_channels=feat_channels,
            pooler_resolution=pooler_resolution,
            dynamic_dim=dynamic_conv['dynamic_dim'],
            dynamic_num=dynamic_conv['dynamic_num'])

        self.linear1 = nn.Linear(feat_channels, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, feat_channels)

        self.norm1 = nn.LayerNorm(feat_channels)
        self.norm2 = nn.LayerNorm(feat_channels)
        self.norm3 = nn.LayerNorm(feat_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = build_activation_layer(act_cfg)

        # block time mlp
        self.block_time_mlp = nn.Sequential(
            nn.SiLU(), nn.Linear(feat_channels * 4, feat_channels * 2))

        # cls.
        cls_module = list()
        for _ in range(num_cls_convs):
            cls_module.append(nn.Linear(feat_channels, feat_channels, False))
            cls_module.append(nn.LayerNorm(feat_channels))
            cls_module.append(nn.ReLU(inplace=True))
        self.cls_module = nn.ModuleList(cls_module)

        # reg.
        reg_module = list()
        for _ in range(num_reg_convs):
            reg_module.append(nn.Linear(feat_channels, feat_channels, False))
            reg_module.append(nn.LayerNorm(feat_channels))
            reg_module.append(nn.ReLU(inplace=True))
        self.reg_module = nn.ModuleList(reg_module)

        # pred.
        self.use_focal_loss = use_focal_loss
        self.use_fed_loss = use_fed_loss
        if self.use_focal_loss or self.use_fed_loss:
            self.class_logits = nn.Linear(feat_channels, num_classes)
        else:
            self.class_logits = nn.Linear(feat_channels, num_classes + 1)
        self.bboxes_delta = nn.Linear(feat_channels, 4)
        self.scale_clamp = scale_clamp
        self.bbox_weights = bbox_weights

    def forward(self, features, bboxes, pro_features, pooler, time_emb):
        """
        :param bboxes: (N, num_boxes, 4)
        :param pro_features: (N, num_boxes, feat_channels)
        """
        N, num_boxes = bboxes.shape[:2]

        # roi_feature.
        proposal_boxes = list()
        for b in range(N):
            proposal_boxes.append(bboxes[b])
        rois = bbox2roi(proposal_boxes)

        roi_features = pooler(features, rois)

        if pro_features is None:
            pro_features = roi_features.view(N, num_boxes, self.feat_channels,
                                             -1).mean(-1)

        roi_features = roi_features.view(N * num_boxes, self.feat_channels,
                                         -1).permute(2, 0, 1)

        # self_att.
        pro_features = pro_features.view(N, num_boxes,
                                         self.feat_channels).permute(1, 0, 2)
        pro_features2 = self.self_attn(
            pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.view(
            num_boxes, N,
            self.feat_channels).permute(1, 0,
                                        2).reshape(1, N * num_boxes,
                                                   self.feat_channels)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(
            self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(0, 1).reshape(N * num_boxes, -1)

        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_boxes, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))

        return (class_logits.view(N, num_boxes,
                                  -1), pred_bboxes.view(N, num_boxes,
                                                        -1), obj_features)

    def apply_deltas(self, deltas, boxes):
        """Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4),
                where k >= 1. deltas[i] represents k potentially
                different class-specific box transformations for
                the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
        return pred_boxes
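
    # Worked example (assumed numbers, added for clarity): for a proposal
    # (0, 0, 100, 100), ``bbox_weights = (2.0, 2.0, 1.0, 1.0)`` and deltas
    # (2.0, 0.0, 0.0, 0.0), dx = 2.0 / 2.0 = 1.0, so the centre shifts by
    # 1.0 * width = 100 px while width and height stay exp(0) = 1 times the
    # original, giving the box (100, 0, 200, 100).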


class DynamicConv(nn.Module):

    def __init__(self,
                 feat_channels: int,
                 dynamic_dim: int = 64,
                 dynamic_num: int = 2,
                 pooler_resolution: int = 7) -> None:
        super().__init__()
        self.feat_channels = feat_channels
        self.dynamic_dim = dynamic_dim
        self.dynamic_num = dynamic_num
        self.num_params = self.feat_channels * self.dynamic_dim
        self.dynamic_layer = nn.Linear(self.feat_channels,
                                       self.dynamic_num * self.num_params)

        self.norm1 = nn.LayerNorm(self.dynamic_dim)
        self.norm2 = nn.LayerNorm(self.feat_channels)

        self.activation = nn.ReLU(inplace=True)

        num_output = self.feat_channels * pooler_resolution**2
        self.out_layer = nn.Linear(num_output, self.feat_channels)
        self.norm3 = nn.LayerNorm(self.feat_channels)

    def forward(self, pro_features: Tensor, roi_features: Tensor) -> Tensor:
        """Forward function.

        Args:
            pro_features: (1, N * num_boxes, self.feat_channels)
            roi_features: (49, N * num_boxes, self.feat_channels)

        Returns:
            Tensor: Object features of shape
            (N * num_boxes, self.feat_channels).
        """
        features = roi_features.permute(1, 0, 2)
        parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)

        param1 = parameters[:, :, :self.num_params].view(
            -1, self.feat_channels, self.dynamic_dim)
        param2 = parameters[:, :,
                            self.num_params:].view(-1, self.dynamic_dim,
                                                   self.feat_channels)

        features = torch.bmm(features, param1)
        features = self.norm1(features)
        features = self.activation(features)

        features = torch.bmm(features, param2)
        features = self.norm2(features)
        features = self.activation(features)

        features = features.flatten(1)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)
        return features
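
# Illustrative shape check for ``DynamicConv`` (assumed sizes, added comment):
#   >>> conv = DynamicConv(feat_channels=256)
#   >>> pro = torch.randn(1, 8, 256)    # (1, N * num_boxes, feat_channels)
#   >>> roi = torch.randn(49, 8, 256)   # (7 * 7, N * num_boxes, feat_channels)
#   >>> conv(pro, roi).shape
#   torch.Size([8, 256])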
|