# Source code of mmyolo.models.losses.oks_loss
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from mmyolo.registry import MODELS
try:
from mmpose.datasets.datasets.utils import parse_pose_metainfo
except ImportError:
parse_pose_metainfo = None
@MODELS.register_module()
class OksLoss(nn.Module):
    """A PyTorch implementation of the Object Keypoint Similarity (OKS) loss as
    described in the paper "YOLO-Pose: Enhancing YOLO for Multi Person Pose
    Estimation Using Object Keypoint Similarity Loss" by Debapriya et al.
    (2022).

    The OKS loss is used for keypoint-based object recognition and consists
    of a measure of the similarity between predicted and ground truth
    keypoint locations, adjusted by the size of the object in the image.

    The loss function takes as input the predicted keypoint locations, the
    ground truth keypoint locations, a mask indicating which keypoints are
    valid, and bounding boxes for the objects. The returned loss is
    ``(1 - OKS) * loss_weight``.

    Args:
        metainfo (Optional[str]): Path to a dataset meta-information file.
            It is parsed with mmpose's
            ``parse_pose_metainfo(dict(from_file=metainfo))``, which requires
            mmpose to be installed. If the parsed metainfo contains
            ``sigmas``, they are registered as the ``sigmas`` buffer and
            used as per-keypoint scale factors. Defaults to None, meaning no
            per-keypoint scaling.
        loss_weight (float): Weight for the loss. Defaults to 1.0.

    Raises:
        ImportError: If ``metainfo`` is given but mmpose is not installed.
    """

    def __init__(self,
                 metainfo: Optional[str] = None,
                 loss_weight: float = 1.0):
        super().__init__()
        if metainfo is not None:
            if parse_pose_metainfo is None:
                raise ImportError(
                    'Please run "mim install -r requirements/mmpose.txt" '
                    'to install mmpose first for OksLoss.')
            metainfo = parse_pose_metainfo(dict(from_file=metainfo))
            sigmas = metainfo.get('sigmas', None)
            if sigmas is not None:
                # Stored as a (non-trainable) buffer so it moves with the
                # module across devices and is kept in the state dict.
                self.register_buffer('sigmas', torch.as_tensor(sigmas))
        self.loss_weight = loss_weight

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Tensor,
                bboxes: Optional[Tensor] = None) -> Tensor:
        """Compute the weighted OKS loss ``(1 - OKS) * loss_weight``.

        Args:
            output (Tensor): Predicted keypoints, see :meth:`compute_oks`.
            target (Tensor): Ground truth keypoints, same shape as
                ``output``.
            target_weights (Tensor): Per-keypoint validity weights.
            bboxes (Optional[Tensor]): Boxes in xyxy format used to
                normalize distances by object scale. Defaults to None.

        Returns:
            Tensor: Per-instance loss with the leading (instance) shape of
            the inputs; the keypoint dimension is reduced.
        """
        oks = self.compute_oks(output, target, target_weights, bboxes)
        loss = 1 - oks
        return loss * self.loss_weight

    def compute_oks(self,
                    output: Tensor,
                    target: Tensor,
                    target_weights: Tensor,
                    bboxes: Optional[Tensor] = None) -> Tensor:
        """Calculate the Object Keypoint Similarity (OKS).

        Note that this returns the similarity itself, not the loss; the
        loss is ``1 - OKS`` (see :meth:`forward`).

        Args:
            output (Tensor): Predicted keypoints in shape N x k x 2, where N
                is batch size, k is the number of keypoints, and 2 are the
                xy coordinates. Extra leading dimensions are also accepted,
                since the distance and reductions only use the last axes.
            target (Tensor): Ground truth keypoints in the same shape as
                output.
            target_weights (Tensor): Mask of valid keypoints in shape N x k,
                with 1 for valid and 0 for invalid.
            bboxes (Optional[Tensor]): Bounding boxes in shape N x 4,
                where 4 are the xyxy coordinates. Defaults to None.

        Returns:
            Tensor: The OKS per instance (shape N). It is the weighted mean
            of ``exp(-d**2 / 2)`` over keypoints, so it lies in [0, 1] for
            non-negative weights, and it is 0 for instances whose weights
            are all zero.
        """
        # Euclidean distance between predicted and target keypoints: (N, k).
        dist = torch.norm(output - target, dim=-1)

        if hasattr(self, 'sigmas'):
            # Reshape the (k,) sigmas so they broadcast over every leading
            # dimension of ``dist``.
            sigmas = self.sigmas.reshape(*((1, ) * (dist.ndim - 1)), -1)
            dist = dist / sigmas
        if bboxes is not None:
            # The object scale used here is the length of the box diagonal
            # (not its area). Clipping guards degenerate (zero-size) boxes.
            diag = torch.norm(bboxes[..., 2:] - bboxes[..., :2], dim=-1)
            dist = dist / diag.clip(min=1e-8).unsqueeze(-1)

        weighted_sim = torch.exp(-dist.pow(2) / 2) * target_weights
        # Clipping the weight sum avoids 0 / 0 = NaN for instances with no
        # valid keypoints; their numerator is 0, so the OKS becomes 0.
        weight_sum = target_weights.sum(dim=-1).clip(min=1e-8)
        return weighted_sim.sum(dim=-1) / weight_sum