Shortcuts

mmyolo.models.task_modules.assigners.batch_task_aligned_assigner 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmyolo.models.losses import bbox_overlaps
from mmyolo.registry import TASK_UTILS
from .utils import (select_candidates_in_gts, select_highest_overlaps,
                    yolov6_iou_calculator)


@TASK_UTILS.register_module()
class BatchTaskAlignedAssigner(nn.Module):
    """Batched task-aligned assigner.

    This code referenced to
    https://github.com/meituan/YOLOv6/blob/main/yolov6/
    assigners/tal_assigner.py.

    Batch Task aligned assigner base on the paper:
    `TOOD: Task-aligned One-stage Object Detection.
    <https://arxiv.org/abs/2108.07755>`_.

    For a batch of predicted bboxes, ``forward`` returns a dict holding
    the per-prior assigned labels, assigned bboxes, assigned scores and a
    foreground mask (keys ``assigned_labels``, ``assigned_bboxes``,
    ``assigned_scores`` and ``fg_mask_pre_prior``).

    Args:
        num_classes (int): Number of classes.
        topk (int): Number of candidate priors selected for each gt
            according to the alignment metric. Defaults to 13.
        alpha (float): Exponent applied to the classification score in
            the alignment metric. Defaults to 1.0.
        beta (float): Exponent applied to the overlap in the alignment
            metric. Defaults to 6.0.
        eps (float): Small constant added to the denominator when the
            alignment metric is normalized; also the threshold used to
            build the default top-k mask. Defaults to 1e-7.
        use_ciou (bool): Whether to use ciou while calculating iou.
            Defaults to False.
    """

    def __init__(self,
                 num_classes: int,
                 topk: int = 13,
                 alpha: float = 1.0,
                 beta: float = 6.0,
                 eps: float = 1e-7,
                 use_ciou: bool = False):
        super().__init__()
        self.num_classes = num_classes
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.use_ciou = use_ciou

    @torch.no_grad()
    def forward(
        self,
        pred_bboxes: Tensor,
        pred_scores: Tensor,
        priors: Tensor,
        gt_labels: Tensor,
        gt_bboxes: Tensor,
        pad_bbox_flag: Tensor,
    ) -> dict:
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute alignment metric between all bbox (bbox of all pyramid
           levels) and gt
        2. select top-k bbox as candidates for each gt
        3. limit the positive sample's center in gt (because the
           anchor-free detector only can predict positive distance)

        Args:
            pred_bboxes (Tensor): Predict bboxes,
                shape(batch_size, num_priors, 4)
            pred_scores (Tensor): Scores of predict bboxes,
                shape(batch_size, num_priors, num_classes)
            priors (Tensor): Model priors, shape (num_priors, 4). Only
                the first two columns are used.
            gt_labels (Tensor): Ground true labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground true bboxes,
                shape(batch_size, num_gt, 4)
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                1 means bbox, 0 means no bbox,
                shape(batch_size, num_gt, 1)

        Returns:
            dict: Assigned result with the keys

            - assigned_labels (Tensor): Assigned labels,
              shape(batch_size, num_priors)
            - assigned_bboxes (Tensor): Assigned boxes,
              shape(batch_size, num_priors, 4)
            - assigned_scores (Tensor): Assigned scores,
              shape(batch_size, num_priors, num_classes)
            - fg_mask_pre_prior (Tensor): Foreground mask produced by
              ``select_highest_overlaps``,
              shape(batch_size, num_priors)
        """
        # (num_priors, 4) -> (num_priors, 2)
        priors = priors[:, :2]

        batch_size = pred_scores.size(0)
        num_gt = gt_bboxes.size(1)

        # Default result, returned as-is when the batch holds no gt slot.
        # NOTE: every tensor here is built with ``gt_bboxes.new_full`` and
        # therefore shares the dtype of ``gt_bboxes``, whereas the regular
        # path below returns a bool ``fg_mask_pre_prior``.
        assigned_result = {
            'assigned_labels':
            gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
            'assigned_bboxes':
            gt_bboxes.new_full(pred_bboxes.shape, 0),
            'assigned_scores':
            gt_bboxes.new_full(pred_scores.shape, 0),
            'fg_mask_pre_prior':
            gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
        }

        if num_gt == 0:
            return assigned_result

        pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
            pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
            pad_bbox_flag, batch_size, num_gt)

        assigned_gt_idxs, fg_mask_pre_prior, pos_mask = \
            select_highest_overlaps(pos_mask, overlaps, num_gt)

        # assigned target
        assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
            gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
            batch_size, num_gt)

        # normalize: rescale each gt's metrics so that its best positive
        # prior receives (approximately) that gt's best positive overlap,
        # then keep the largest value over gts for every prior.
        alignment_metrics *= pos_mask
        pos_align_metrics = alignment_metrics.max(dim=-1, keepdim=True)[0]
        pos_overlaps = (overlaps * pos_mask).max(dim=-1, keepdim=True)[0]
        norm_align_metric = (
            alignment_metrics * pos_overlaps /
            (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
        assigned_scores = assigned_scores * norm_align_metric

        assigned_result['assigned_labels'] = assigned_labels
        assigned_result['assigned_bboxes'] = assigned_bboxes
        assigned_result['assigned_scores'] = assigned_scores
        assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
        return assigned_result

    def get_pos_mask(self, pred_bboxes: Tensor, pred_scores: Tensor,
                     priors: Tensor, gt_labels: Tensor, gt_bboxes: Tensor,
                     pad_bbox_flag: Tensor, batch_size: int,
                     num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Get possible mask.

        Args:
            pred_bboxes (Tensor): Predict bboxes,
                shape(batch_size, num_priors, 4)
            pred_scores (Tensor): Scores of predict bbox,
                shape(batch_size, num_priors, num_classes)
            priors (Tensor): Model priors, shape (num_priors, 2)
            gt_labels (Tensor): Ground true labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground true bboxes,
                shape(batch_size, num_gt, 4)
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                1 means bbox, 0 means no bbox,
                shape(batch_size, num_gt, 1)
            batch_size (int): Batch size.
            num_gt (int): Number of ground truth.

        Returns:
            Tuple[Tensor, Tensor, Tensor]:

            - pos_mask (Tensor): Possible mask,
              shape(batch_size, num_gt, num_priors)
            - alignment_metrics (Tensor): Alignment metrics,
              shape(batch_size, num_gt, num_priors)
            - overlaps (Tensor): Overlaps of gt_bboxes and pred_bboxes,
              shape(batch_size, num_gt, num_priors)
        """
        # Compute alignment metric between all bbox and gt
        alignment_metrics, overlaps = self.get_box_metrics(
            pred_bboxes, pred_scores, gt_labels, gt_bboxes, batch_size,
            num_gt)

        # get is_in_gts mask
        is_in_gts = select_candidates_in_gts(priors, gt_bboxes)

        # get topk_metric mask; padded gt slots are masked out through
        # ``pad_bbox_flag`` repeated once per top-k entry
        topk_metric = self.select_topk_candidates(
            alignment_metrics * is_in_gts,
            topk_mask=pad_bbox_flag.repeat([1, 1, self.topk]).bool())

        # merge all mask to a final mask
        pos_mask = topk_metric * is_in_gts * pad_bbox_flag

        return pos_mask, alignment_metrics, overlaps

    def get_box_metrics(self, pred_bboxes: Tensor, pred_scores: Tensor,
                        gt_labels: Tensor, gt_bboxes: Tensor,
                        batch_size: int,
                        num_gt: int) -> Tuple[Tensor, Tensor]:
        """Compute alignment metric between all bbox and gt.

        The metric is ``score ** alpha * overlap ** beta`` where ``score``
        is the predicted score of the gt's class.

        Args:
            pred_bboxes (Tensor): Predict bboxes,
                shape(batch_size, num_priors, 4)
            pred_scores (Tensor): Scores of predict bbox,
                shape(batch_size, num_priors, num_classes)
            gt_labels (Tensor): Ground true labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground true bboxes,
                shape(batch_size, num_gt, 4)
            batch_size (int): Batch size.
            num_gt (int): Number of ground truth.

        Returns:
            Tuple[Tensor, Tensor]:

            - alignment_metrics (Tensor): Align metric,
              shape(batch_size, num_gt, num_priors)
            - overlaps (Tensor): Overlaps,
              shape(batch_size, num_gt, num_priors)
        """
        # (batch_size, num_priors, num_classes)
        #   -> (batch_size, num_classes, num_priors)
        pred_scores = pred_scores.permute(0, 2, 1)

        # Build the gather indices on the device of the indexed tensor so
        # no CPU index tensors are used against accelerator data.
        class_idx = gt_labels.to(
            device=pred_scores.device, dtype=torch.long).squeeze(-1)
        batch_idx = torch.arange(
            batch_size, dtype=torch.long,
            device=pred_scores.device).view(-1, 1).expand(-1, num_gt)
        # For every (image, gt) pair pick the score row of the gt's class:
        # (batch_size, num_gt, num_priors)
        bbox_scores = pred_scores[batch_idx, class_idx]

        # TODO: need to replace the yolov6_iou_calculator function
        if self.use_ciou:
            overlaps = bbox_overlaps(
                pred_bboxes.unsqueeze(1),
                gt_bboxes.unsqueeze(2),
                iou_mode='ciou',
                bbox_format='xyxy').clamp(0)
        else:
            overlaps = yolov6_iou_calculator(gt_bboxes, pred_bboxes)

        alignment_metrics = bbox_scores.pow(self.alpha) * overlaps.pow(
            self.beta)

        return alignment_metrics, overlaps

    def select_topk_candidates(self,
                               alignment_gt_metrics: Tensor,
                               using_largest_topk: bool = True,
                               topk_mask: Optional[Tensor] = None) -> Tensor:
        """Select the top-k candidate priors of every gt.

        Args:
            alignment_gt_metrics (Tensor): Alignment metric of gt
                candidates, shape(batch_size, num_gt, num_priors)
            using_largest_topk (bool): Controls whether to using largest
                or smallest elements.
            topk_mask (Tensor, optional): Topk mask,
                shape(batch_size, num_gt, self.topk). When None, a gt is
                kept only if its best top-k metric exceeds ``self.eps``.

        Returns:
            Tensor: Topk candidates mask,
            shape(batch_size, num_gt, num_priors)
        """
        num_priors = alignment_gt_metrics.shape[-1]
        topk_metrics, topk_idxs = torch.topk(
            alignment_gt_metrics,
            self.topk,
            dim=-1,
            largest=using_largest_topk)
        if topk_mask is None:
            # ``max`` with a dim returns (values, indices); only the
            # values can be compared against the threshold.
            topk_mask = (topk_metrics.max(dim=-1, keepdim=True)[0] >
                         self.eps).tile([1, 1, self.topk])
        # Masked-out entries all collapse onto prior index 0 ...
        topk_idxs = torch.where(topk_mask, topk_idxs,
                                torch.zeros_like(topk_idxs))
        is_in_topk = F.one_hot(topk_idxs, num_priors).sum(dim=-2)
        # ... so any prior counted more than once is cleared again.
        is_in_topk = torch.where(is_in_topk > 1,
                                 torch.zeros_like(is_in_topk), is_in_topk)
        return is_in_topk.to(alignment_gt_metrics.dtype)

    def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
                    assigned_gt_idxs: Tensor, fg_mask_pre_prior: Tensor,
                    batch_size: int,
                    num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Get assigner info.

        Args:
            gt_labels (Tensor): Ground true labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground true bboxes,
                shape(batch_size, num_gt, 4)
            assigned_gt_idxs (Tensor): Assigned ground truth indexes,
                shape(batch_size, num_priors)
            fg_mask_pre_prior (Tensor): Foreground mask,
                shape(batch_size, num_priors)
            batch_size (int): Batch size.
            num_gt (int): Number of ground truth.

        Returns:
            Tuple[Tensor, Tensor, Tensor]:

            - assigned_labels (Tensor): Assigned labels,
              shape(batch_size, num_priors)
            - assigned_bboxes (Tensor): Assigned bboxes,
              shape(batch_size, num_priors, 4)
            - assigned_scores (Tensor): One-hot assigned scores, zeroed
              where the foreground mask is not positive,
              shape(batch_size, num_priors, num_classes)
        """
        # assigned target labels: offset the per-image gt index so it
        # addresses the gt tensors flattened over the batch dimension
        batch_ind = torch.arange(
            end=batch_size, dtype=torch.int64,
            device=gt_labels.device)[..., None]
        assigned_gt_idxs = assigned_gt_idxs + batch_ind * num_gt
        assigned_labels = gt_labels.long().flatten()[assigned_gt_idxs]

        # assigned target boxes
        assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_idxs]

        # assigned target scores; negative labels are mapped to class 0
        # so that one_hot accepts them
        assigned_labels[assigned_labels < 0] = 0
        assigned_scores = F.one_hot(assigned_labels, self.num_classes)
        force_gt_scores_mask = fg_mask_pre_prior[:, :, None].repeat(
            1, 1, self.num_classes)
        assigned_scores = torch.where(force_gt_scores_mask > 0,
                                      assigned_scores,
                                      torch.full_like(assigned_scores, 0))

        return assigned_labels, assigned_bboxes, assigned_scores
Read the Docs v: latest
Versions
latest
stable
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.