Source code for mmyolo.models.dense_heads.yolox_pose_head
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmdet.models.utils import filter_scores_and_topk
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.model import ModuleList, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from ..utils import OutputSaveFunctionWrapper, OutputSaveObjectWrapper
from .yolox_head import YOLOXHead, YOLOXHeadModule
@MODELS.register_module()
class YOLOXPoseHeadModule(YOLOXHeadModule):
"""YOLOXPoseHeadModule serves as a head module for `YOLOX-Pose`.
In comparison to `YOLOXHeadModule`, this module introduces branches for
keypoint prediction.
"""
def __init__(self, num_keypoints: int, *args, **kwargs):
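        # `num_keypoints` must be set before `super().__init__()`, which
        # calls `_init_layers()` and relies on this attribute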
self.num_keypoints = num_keypoints
super().__init__(*args, **kwargs)
def _init_layers(self):
"""Initializes the layers in the head module."""
super()._init_layers()
        # The pose branch requires additional layers for precise regression,
        # so double the depth of the stacked convs built below
        self.stacked_convs *= 2
# Create separate layers for each level of feature maps
pose_convs, offsets_preds, vis_preds = [], [], []
for _ in self.featmap_strides:
pose_convs.append(self._build_stacked_convs())
offsets_preds.append(
nn.Conv2d(self.feat_channels, self.num_keypoints * 2, 1))
vis_preds.append(
nn.Conv2d(self.feat_channels, self.num_keypoints, 1))
self.multi_level_pose_convs = ModuleList(pose_convs)
self.multi_level_conv_offsets = ModuleList(offsets_preds)
self.multi_level_conv_vis = ModuleList(vis_preds)
    def init_weights(self):
"""Initialize weights of the head."""
super().init_weights()
# Use prior in model initialization to improve stability
bias_init = bias_init_with_prob(0.01)
for conv_vis in self.multi_level_conv_vis:
conv_vis.bias.data.fill_(bias_init)
    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network."""
offsets_pred, vis_pred = [], []
for i in range(len(x)):
pose_feat = self.multi_level_pose_convs[i](x[i])
offsets_pred.append(self.multi_level_conv_offsets[i](pose_feat))
vis_pred.append(self.multi_level_conv_vis[i](pose_feat))
return (*super().forward(x), offsets_pred, vis_pred)
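    # A shape sketch of the forward outputs: the returned tuple is
    # (cls_scores, bbox_preds, objectnesses, offsets_pred, vis_pred), each a
    # list with one entry per feature level; for a (B, C, H, W) input
    # feature, offsets_pred[i] has shape (B, num_keypoints * 2, H, W) and
    # vis_pred[i] has shape (B, num_keypoints, H, W).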
@MODELS.register_module()
class YOLOXPoseHead(YOLOXHead):
"""YOLOXPoseHead head used in `YOLO-Pose.
<https://arxiv.org/abs/2204.06806>`_.
Args:
loss_pose (ConfigDict, optional): Config of keypoint OKS loss.
"""
def __init__(
self,
loss_pose: Optional[ConfigType] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.loss_pose = MODELS.build(loss_pose)
self.num_keypoints = self.head_module.num_keypoints
        # set up buffers to save variables generated in methods of
        # the base class
self._log = defaultdict(list)
self.sampler = OutputSaveObjectWrapper(self.sampler)
# ensure that the `sigmas` in self.assigner.oks_calculator
# is on the same device as the model
if hasattr(self.assigner, 'oks_calculator'):
self.add_module('assigner_oks_calculator',
self.assigner.oks_calculator)
def _clear(self):
"""Clear variable buffers."""
self.sampler.clear()
self._log.clear()
    def loss(self, x: Tuple[Tensor],
             batch_data_samples: Union[list, dict]) -> dict:
        """Calculate losses from a batch of inputs and data samples."""
        if isinstance(batch_data_samples, list):
            losses = super().loss(x, batch_data_samples)
        else:
            # Fast version: ground truths are packed into a single dict
            # instead of a list of data samples
            outs = self(x)
            loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                                  batch_data_samples['keypoints'],
                                  batch_data_samples['keypoints_visible'],
                                  batch_data_samples['img_metas'])
            losses = self.loss_by_feat(*loss_inputs)
        return losses
    def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
kpt_preds: Sequence[Tensor],
vis_preds: Sequence[Tensor],
batch_gt_instances: Tensor,
batch_gt_keypoints: Tensor,
batch_gt_keypoints_visible: Tensor,
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
In addition to the base class method, keypoint losses are also
calculated in this method.
"""
self._clear()
batch_gt_instances = self.gt_kps_instances_preprocess(
batch_gt_instances, batch_gt_keypoints, batch_gt_keypoints_visible,
len(batch_img_metas))
        # collect keypoint coordinates and visibility from model predictions
kpt_preds = torch.cat([
kpt_pred.flatten(2).permute(0, 2, 1).contiguous()
for kpt_pred in kpt_preds
],
dim=1)
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
grid_priors = torch.cat(mlvl_priors)
flatten_kpts = self.decode_pose(grid_priors[..., :2], kpt_preds,
grid_priors[..., 2])
vis_preds = torch.cat([
vis_pred.flatten(2).permute(0, 2, 1).contiguous()
for vis_pred in vis_preds
],
dim=1)
        # compute detection losses and collect targets for keypoint
        # predictions simultaneously
self._log['pred_keypoints'] = list(flatten_kpts.detach().split(
1, dim=0))
self._log['pred_keypoints_vis'] = list(vis_preds.detach().split(
1, dim=0))
losses = super().loss_by_feat(cls_scores, bbox_preds, objectnesses,
batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
kpt_targets, vis_targets = [], []
sampling_results = self.sampler.log['sample']
sampling_result_idx = 0
for gt_instances in batch_gt_instances:
if len(gt_instances) > 0:
sampling_result = sampling_results[sampling_result_idx]
kpt_target = gt_instances['keypoints'][
sampling_result.pos_assigned_gt_inds]
vis_target = gt_instances['keypoints_visible'][
sampling_result.pos_assigned_gt_inds]
sampling_result_idx += 1
kpt_targets.append(kpt_target)
vis_targets.append(vis_target)
        # compute keypoint losses
        if len(kpt_targets) > 0:
            kpt_targets = torch.cat(kpt_targets, 0)
            vis_targets = torch.cat(vis_targets, 0)
vis_targets = (vis_targets > 0).float()
pos_masks = torch.cat(self._log['foreground_mask'], 0)
bbox_targets = torch.cat(self._log['bbox_target'], 0)
loss_kpt = self.loss_pose(
flatten_kpts.view(-1, self.num_keypoints, 2)[pos_masks],
kpt_targets, vis_targets, bbox_targets)
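            # visibility logits are supervised with the classification loss,
            # normalized by the number of visible target keypoints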
loss_vis = self.loss_cls(
vis_preds.view(-1, self.num_keypoints)[pos_masks],
vis_targets) / vis_targets.sum()
else:
loss_kpt = kpt_preds.sum() * 0
loss_vis = vis_preds.sum() * 0
losses.update(dict(loss_kpt=loss_kpt, loss_vis=loss_vis))
self._clear()
return losses
@torch.no_grad()
def _get_targets_single(
self,
priors: Tensor,
cls_preds: Tensor,
decoded_bboxes: Tensor,
objectness: Tensor,
gt_instances: InstanceData,
img_meta: dict,
gt_instances_ignore: Optional[InstanceData] = None) -> tuple:
"""Calculates targets for a single image, and saves them to the log.
This method is similar to the _get_targets_single method in the base
class, but additionally saves the foreground mask and bbox targets to
the log.
"""
# Construct a combined representation of bboxes and keypoints to
# ensure keypoints are also involved in the positive sample
# assignment process
kpt = self._log['pred_keypoints'].pop(0).squeeze(0)
kpt_vis = self._log['pred_keypoints_vis'].pop(0).squeeze(0)
kpt = torch.cat((kpt, kpt_vis.unsqueeze(-1)), dim=-1)
decoded_bboxes = torch.cat((decoded_bboxes, kpt.flatten(1)), dim=1)
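        # the combined tensor now has shape (num_priors, 4 + 3 * K): the
        # four bbox coordinates followed by (x, y, visibility) per keypoint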
targets = super()._get_targets_single(priors, cls_preds,
decoded_bboxes, objectness,
gt_instances, img_meta,
gt_instances_ignore)
self._log['foreground_mask'].append(targets[0])
self._log['bbox_target'].append(targets[3])
return targets
    def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
kpt_preds: Optional[List[Tensor]] = None,
vis_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> List[InstanceData]:
"""Transform a batch of output features extracted by the head into bbox
and keypoint results.
In addition to the base class method, keypoint predictions are also
calculated in this method.
"""
"""calculate predicted bboxes and get the kept instances indices.
use OutputSaveFunctionWrapper as context manager to obtain
intermediate output from a parent class without copying a
arge block of code
"""
with OutputSaveFunctionWrapper(
filter_scores_and_topk,
super().predict_by_feat.__globals__) as outputs_1:
with OutputSaveFunctionWrapper(
batched_nms,
super()._bbox_post_process.__globals__) as outputs_2:
results_list = super().predict_by_feat(cls_scores, bbox_preds,
objectnesses,
batch_img_metas, cfg,
rescale, with_nms)
keep_indices_topk = [
out[2][:cfg.max_per_img] for out in outputs_1
]
keep_indices_nms = [
out[1][:cfg.max_per_img] for out in outputs_2
]
num_imgs = len(batch_img_metas)
# recover keypoints coordinates from model predictions
featmap_sizes = [vis_pred.shape[2:] for vis_pred in vis_preds]
priors = torch.cat(self.mlvl_priors)
strides = [
priors.new_full((featmap_size.numel() * self.num_base_priors, ),
stride) for featmap_size, stride in zip(
featmap_sizes, self.featmap_strides)
]
strides = torch.cat(strides)
kpt_preds = torch.cat([
kpt_pred.permute(0, 2, 3, 1).reshape(
num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
],
dim=1)
flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)
vis_preds = torch.cat([
vis_pred.permute(0, 2, 3, 1).reshape(
num_imgs, -1, self.num_keypoints) for vis_pred in vis_preds
],
dim=1).sigmoid()
        # select keypoint predictions according to bbox scores and the NMS
        # results
keep_indices_nms_idx = 0
for pred_instances, kpts, kpts_vis, img_meta, keep_idxs \
in zip(
results_list, flatten_decoded_kpts, vis_preds,
batch_img_metas, keep_indices_topk):
pred_instances.bbox_scores = pred_instances.scores
if len(pred_instances) == 0:
pred_instances.keypoints = kpts[:0]
pred_instances.keypoint_scores = kpts_vis[:0]
continue
kpts = kpts[keep_idxs]
kpts_vis = kpts_vis[keep_idxs]
if rescale:
                pad_param = img_meta.get('pad_param', None)
scale_factor = img_meta['scale_factor']
if pad_param is not None:
kpts -= kpts.new_tensor([pad_param[2], pad_param[0]])
kpts /= kpts.new_tensor(scale_factor).repeat(
(1, self.num_keypoints, 1))
keep_idxs_nms = keep_indices_nms[keep_indices_nms_idx]
kpts = kpts[keep_idxs_nms]
kpts_vis = kpts_vis[keep_idxs_nms]
keep_indices_nms_idx += 1
pred_instances.keypoints = kpts
pred_instances.keypoint_scores = kpts_vis
results_list = [r.numpy() for r in results_list]
return results_list
    def decode_pose(self, grids: torch.Tensor, offsets: torch.Tensor,
strides: Union[torch.Tensor, int]) -> torch.Tensor:
"""Decode regression offsets to keypoints.
Args:
grids (torch.Tensor): The coordinates of the feature map grids.
offsets (torch.Tensor): The predicted offset of each keypoint
relative to its corresponding grid.
strides (torch.Tensor | int): The stride of the feature map for
each instance.
Returns:
torch.Tensor: The decoded keypoints coordinates.
"""
if isinstance(strides, int):
strides = torch.tensor([strides]).to(offsets)
strides = strides.reshape(1, -1, 1, 1)
offsets = offsets.reshape(*offsets.shape[:2], -1, 2)
xy_coordinates = (offsets[..., :2] * strides) + grids.unsqueeze(1)
return xy_coordinates
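    # A minimal shape sketch for `decode_pose` (the tensors and the `head`
    # instance are illustrative assumptions): with grids of shape (N, 2),
    # offsets of shape (B, N, K * 2) and per-prior strides of shape (N,),
    #
    #   grids = torch.tensor([[8., 8.], [24., 8.]])
    #   offsets = torch.zeros(1, 2, 17 * 2)
    #   strides = torch.tensor([8., 8.])
    #   kpts = head.decode_pose(grids, offsets, strides)  # (1, 2, 17, 2)
    #
    # each keypoint lands at grid + offset * stride, i.e. exactly at the
    # grid centers here since all offsets are zero.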
    @staticmethod
    def gt_kps_instances_preprocess(batch_gt_instances: Tensor,
                                    batch_gt_keypoints,
                                    batch_gt_keypoints_visible,
                                    batch_size: int) -> List[InstanceData]:
        """Split batch_gt_instances by batch size.
        Args:
            batch_gt_instances (Tensor): Ground truth instances for the
                whole batch, a 2D tensor of shape [all_gt_bboxes, 6] whose
                rows hold [img_idx, label] followed by the four bbox
                coordinates.
            batch_gt_keypoints (Tensor): Ground truth keypoints for the
                whole batch.
            batch_gt_keypoints_visible (Tensor): Ground truth keypoint
                visibility for the whole batch.
            batch_size (int): Batch size.
        Returns:
            List: batch gt instances data, shape [batch_size, InstanceData]
        """
# faster version
batch_instance_list = []
for i in range(batch_size):
            batch_gt_instance_ = InstanceData()
            mask = batch_gt_instances[:, 0] == i
            single_batch_instance = batch_gt_instances[mask, :]
            keypoints = batch_gt_keypoints[mask, :]
            keypoints_visible = batch_gt_keypoints_visible[mask, :]
batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
batch_gt_instance_.labels = single_batch_instance[:, 1]
batch_gt_instance_.keypoints = keypoints
batch_gt_instance_.keypoints_visible = keypoints_visible
batch_instance_list.append(batch_gt_instance_)
return batch_instance_list
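    # A hedged illustration of the expected layout (all tensors here are
    # assumptions for demonstration):
    #
    #   batch_gt_instances = torch.tensor(
    #       [[0., 3., 10., 10., 50., 50.],   # image 0, label 3
    #        [1., 0., 20., 20., 80., 80.]])  # image 1, label 0
    #   batch_gt_keypoints = torch.zeros(2, 17, 2)
    #   batch_gt_keypoints_visible = torch.ones(2, 17)
    #   instances = YOLOXPoseHead.gt_kps_instances_preprocess(
    #       batch_gt_instances, batch_gt_keypoints,
    #       batch_gt_keypoints_visible, batch_size=2)
    #   # len(instances) == 2; instances[0].bboxes has shape (1, 4)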
    @staticmethod
    def gt_instances_preprocess(batch_gt_instances: List[InstanceData], *args,
                                **kwargs) -> List[InstanceData]:
        """Bypass the base class preprocessing: ground truths were already
        split per image by `gt_kps_instances_preprocess`."""
        return batch_gt_instances