
Source code for mmyolo.models.dense_heads.yolov5_ins_head

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional, Sequence, Tuple, Union

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmdet.models.utils import filter_scores_and_topk, multi_apply
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS
from ..utils import make_divisible
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule


class ProtoModule(BaseModule):
    """Mask Proto module for segmentation models of YOLOv5.

    Args:
        in_channels (int): Number of channels in the input feature map.
        middle_channels (int): Number of channels in the middle feature map.
        mask_channels (int): Number of channels in the output mask feature
            map, i.e. the number of prototype masks produced.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='BN', momentum=0.03, eps=0.001)``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to ``dict(type='SiLU', inplace=True)``.
    """

    def __init__(self,
                 *args,
                 in_channels: int = 32,
                 middle_channels: int = 256,
                 mask_channels: int = 32,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = ConvModule(
            in_channels,
            middle_channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = ConvModule(
            middle_channels,
            middle_channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv3 = ConvModule(
            middle_channels,
            mask_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Forward process of the proto branch."""
        return self.conv3(self.conv2(self.upsample(self.conv1(x))))
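
A minimal shape-check sketch for ``ProtoModule`` (the batch size, the
256-channel input and the 80x80 feature size below are illustrative
assumptions, roughly matching a P3-level YOLOv5 feature map):

proto = ProtoModule(in_channels=256, middle_channels=256, mask_channels=32)
x = torch.randn(2, 256, 80, 80)
protos = proto(x)
# nearest-neighbour upsampling doubles the spatial size and conv3
# projects down to the prototype channel count
assert protos.shape == (2, 32, 160, 160)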


@MODELS.register_module()
class YOLOv5InsHeadModule(YOLOv5HeadModule):
    """Detection and Instance Segmentation Head of YOLOv5.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        mask_channels (int): Number of channels in the mask feature map,
            i.e. the number of prototype masks produced.
        proto_channels (int): Number of channels in the proto feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='BN', momentum=0.03, eps=0.001)``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to ``dict(type='SiLU', inplace=True)``.
    """

    def __init__(self,
                 *args,
                 num_classes: int,
                 mask_channels: int = 32,
                 proto_channels: int = 256,
                 widen_factor: float = 1.0,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 **kwargs):
        self.mask_channels = mask_channels
        self.num_out_attrib_with_proto = 5 + num_classes + mask_channels
        self.proto_channels = make_divisible(proto_channels, widen_factor)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        super().__init__(
            *args,
            num_classes=num_classes,
            widen_factor=widen_factor,
            **kwargs)

    def _init_layers(self):
        """Initialize conv layers in YOLOv5 Ins head."""
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_pred = nn.Conv2d(
                self.in_channels[i],
                self.num_base_priors * self.num_out_attrib_with_proto, 1)
            self.convs_pred.append(conv_pred)

        self.proto_pred = ProtoModule(
            in_channels=self.in_channels[0],
            middle_channels=self.proto_channels,
            mask_channels=self.mask_channels,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
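
As a quick illustration of the ``widen_factor`` scaling in ``__init__``
(a hedged sketch using the ``make_divisible`` helper imported at the top
of this module; the exact rounding behaviour is an assumption):

assert make_divisible(256, 1.0) == 256  # full-width model
assert make_divisible(256, 0.5) == 128  # YOLOv5-s style width, still divisible by 8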
    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, objectnesses, and mask predictions.
        """
        assert len(x) == self.num_levels
        cls_scores, bbox_preds, objectnesses, coeff_preds = multi_apply(
            self.forward_single, x, self.convs_pred)
        mask_protos = self.proto_pred(x[0])

        return cls_scores, bbox_preds, objectnesses, coeff_preds, mask_protos
    def forward_single(
            self, x: Tensor,
            convs_pred: nn.Module) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level."""

        pred_map = convs_pred(x)
        bs, _, ny, nx = pred_map.shape
        pred_map = pred_map.view(bs, self.num_base_priors,
                                 self.num_out_attrib_with_proto, ny, nx)

        cls_score = pred_map[:, :, 5:self.num_classes + 5,
                             ...].reshape(bs, -1, ny, nx)
        bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
        objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
        coeff_pred = pred_map[:, :, self.num_classes + 5:,
                              ...].reshape(bs, -1, ny, nx)

        return cls_score, bbox_pred, objectness, coeff_pred
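
Per prior, each level's prediction map packs 4 bbox values, 1 objectness
logit, ``num_classes`` class logits and ``mask_channels`` mask
coefficients, which is exactly what the slicing in ``forward_single``
relies on. A bookkeeping sketch with assumed COCO-style defaults:

num_base_priors, num_classes, mask_channels = 3, 80, 32
num_out_attrib_with_proto = 5 + num_classes + mask_channels  # 117
pred_map = torch.randn(2, num_base_priors * num_out_attrib_with_proto, 20, 20)
pred_map = pred_map.view(2, num_base_priors, num_out_attrib_with_proto, 20, 20)
coeff_pred = pred_map[:, :, 5 + num_classes:, ...]  # trailing 32 coefficients
assert coeff_pred.shape == (2, 3, 32, 20, 20)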
@MODELS.register_module()
class YOLOv5InsHead(YOLOv5Head):
    """YOLOv5 Instance Segmentation and Detection head.

    Args:
        mask_overlap (bool): Whether to merge the masks of all instances in
            an image into a single overlapped mask, in which the value of
            each pixel is the instance index + 1. Defaults to True.
        loss_mask (:obj:`ConfigDict` or dict): Config of mask loss.
        loss_mask_weight (float): The weight of mask loss. Defaults to 0.05.
    """

    def __init__(self,
                 *args,
                 mask_overlap: bool = True,
                 loss_mask: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='none'),
                 loss_mask_weight=0.05,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_overlap = mask_overlap
        self.loss_mask: nn.Module = MODELS.build(loss_mask)
        self.loss_mask_weight = loss_mask_weight
    def loss(self, x: Tuple[Tensor],
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """

        if isinstance(batch_data_samples, list):
            # TODO: support non-fast version instance segmentation
            raise NotImplementedError
        else:
            outs = self(x)
            # Fast version
            loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                                  batch_data_samples['masks'],
                                  batch_data_samples['img_metas'])
            losses = self.loss_by_feat(*loss_inputs)

        return losses
    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            coeff_preds: Sequence[Tensor],
            proto_preds: Tensor,
            batch_gt_instances: Sequence[InstanceData],
            batch_gt_masks: Sequence[Tensor],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
                scale level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for all scale
                levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            coeff_preds (Sequence[Tensor]): Mask coefficient for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * mask_channels.
            proto_preds (Tensor): Mask prototype features extracted from the
                mask head, has shape (batch_size, mask_channels, H, W).
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_gt_masks (Sequence[Tensor]): Batch of gt_mask.
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        device = cls_scores[0].device
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)
        loss_mask = torch.zeros(1, device=device)
        scaled_factor = torch.ones(8, device=device)

        for i in range(self.num_levels):
            batch_size, _, h, w = bbox_preds[i].shape
            target_obj = torch.zeros_like(objectnesses[i])

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                loss_mask += coeff_preds[i].sum() * 0
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor(
                bbox_preds[i].shape)[[3, 2, 3, 2]]
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 8)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                loss_mask += coeff_preds[i].sum() * 0
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether to assign
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(
                1, offset_inds.shape[1], 1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, \
                priors_targets_inds = _chunk_targets
            (priors_inds, targets_inds) = priors_targets_inds.long().T
            (img_inds, class_inds) = img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)

            # 4. Calculate loss
            # bbox loss
            retained_bbox_pred = bbox_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
            priors_base_sizes_i = priors_base_sizes_i[priors_inds]
            decoded_bbox_pred = self._decode_bbox_to_xywh(
                retained_bbox_pred, priors_base_sizes_i)
            loss_box_i, iou = self.loss_bbox(decoded_bbox_pred, bboxes_targets)
            loss_box += loss_box_i

            # obj loss
            iou = iou.detach().clamp(0)
            target_obj[img_inds, priors_inds, grid_y_inds,
                       grid_x_inds] = iou.type(target_obj.dtype)
            loss_obj += self.loss_obj(objectnesses[i],
                                      target_obj) * self.obj_level_weights[i]

            # cls loss
            if self.num_classes > 1:
                pred_cls_scores = cls_scores[i].reshape(
                    batch_size, self.num_base_priors, -1, h,
                    w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

                target_class = torch.full_like(pred_cls_scores, 0.)
                target_class[range(batch_targets_scaled.shape[0]),
                             class_inds] = 1.
                loss_cls += self.loss_cls(pred_cls_scores, target_class)
            else:
                loss_cls += cls_scores[i].sum() * 0

            # mask regression
            retained_coeff_preds = coeff_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

            _, c, mask_h, mask_w = proto_preds.shape
            if batch_gt_masks.shape[-2:] != (mask_h, mask_w):
                batch_gt_masks = F.interpolate(
                    batch_gt_masks[None], (mask_h, mask_w),
                    mode='nearest')[0]

            xywh_normed = batch_targets_scaled[:, 2:6] / scaled_factor[2:6]
            area_normed = xywh_normed[:, 2:].prod(1)
            xywh_scaled = xywh_normed * torch.tensor(
                proto_preds.shape, device=device)[[3, 2, 3, 2]]
            xyxy_scaled = bbox_cxcywh_to_xyxy(xywh_scaled)

            for bs in range(batch_size):
                match_inds = (img_inds == bs)  # matching index
                if not match_inds.any():
                    continue

                if self.mask_overlap:
                    mask_gti = torch.where(
                        batch_gt_masks[bs][None] ==
                        targets_inds[match_inds].view(-1, 1, 1), 1.0, 0.0)
                else:
                    mask_gti = batch_gt_masks[targets_inds][match_inds]

                mask_preds = (retained_coeff_preds[match_inds]
                              @ proto_preds[bs].view(c, -1)).view(
                                  -1, mask_h, mask_w)
                loss_mask_full = self.loss_mask(mask_preds, mask_gti)
                loss_mask += (
                    self.crop_mask(loss_mask_full[None],
                                   xyxy_scaled[match_inds]).mean(dim=(2, 3)) /
                    area_normed[match_inds]).mean()

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size,
            loss_mask=loss_mask * self.loss_mask_weight * world_size)
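
A toy illustration of the step-2 shape match above (numbers are made up):
a gt box is kept for a prior only when neither its width nor its height
ratio exceeds ``prior_match_thr`` in either direction.

prior_match_thr = 4.0
gt_wh = torch.tensor([[40., 30.]])  # one gt box, feature-map units
priors_wh = torch.tensor([[10., 13.], [33., 23.], [156., 198.]])
wh_ratio = gt_wh[None] / priors_wh[:, None]  # (num_priors, num_gts, 2)
match = torch.max(wh_ratio, 1 / wh_ratio).max(2)[0] < prior_match_thr
# match -> [[False], [True], [False]]: only the 33x23 prior fits the 40x30 box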
    def _convert_gt_to_norm_format(self,
                                   batch_gt_instances: Sequence[InstanceData],
                                   batch_img_metas: Sequence[dict]) -> Tensor:
        """Add target_inds for instance segmentation."""
        batch_targets_normed = super()._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        if self.mask_overlap:
            batch_size = len(batch_img_metas)
            target_inds = []
            for i in range(batch_size):
                # find number of targets of each image
                num_gts = (batch_gt_instances[:, 0] == i).sum()
                # (num_anchor, num_gts)
                target_inds.append(
                    torch.arange(num_gts, device=batch_gt_instances.device).
                    float().view(1, num_gts).repeat(self.num_base_priors, 1) +
                    1)
            target_inds = torch.cat(target_inds, 1)
        else:
            num_gts = batch_gt_instances.shape[0]
            target_inds = torch.arange(
                num_gts, device=batch_gt_instances.device).float().view(
                    1, num_gts).repeat(self.num_base_priors, 1)
        batch_targets_normed = torch.cat(
            [batch_targets_normed, target_inds[..., None]], 2)
        return batch_targets_normed
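
A hedged sketch of the overlap encoding that ``mask_overlap=True``
assumes: each image carries one map whose pixel values are
instance index + 1 (0 is background), so the ``target_inds`` appended
above let ``loss_by_feat`` recover per-instance binary masks:

overlap = torch.tensor([[0., 1., 1.],
                        [2., 2., 0.],
                        [0., 0., 2.]])
target_inds = torch.tensor([1., 2.])  # the two instances of this image
binary = torch.where(overlap[None] == target_inds.view(-1, 1, 1), 1.0, 0.0)
assert binary.shape == (2, 3, 3)  # one binary mask per instance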
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        coeff_preds: Optional[List[Tensor]] = None,
                        proto_preds: Optional[Tensor] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted from the head into
        bbox and mask results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it then obtain the real score used in NMS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            coeff_preds (list[Tensor]): Mask coefficients predictions
                for all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * mask_channels, H, W).
            proto_preds (Tensor): Mask prototype features extracted from the
                mask head, has shape (batch_size, mask_channels, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing configuration,
                if None, test_cfg would be used. Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, apply NMS before returning boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection and instance
            segmentation results of each image after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, h, w).
        """
        assert len(cls_scores) == len(bbox_preds) == len(coeff_preds)
        if objectnesses is None:
            with_objectnesses = False
        else:
            with_objectnesses = True
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride)
            for featmap_size, stride in zip(featmap_sizes,
                                            self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_coeff_preds = [
            coeff_pred.permute(0, 2, 3,
                               1).reshape(num_imgs, -1,
                                          self.head_module.mask_channels)
            for coeff_pred in coeff_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors.unsqueeze(0), flatten_bbox_preds, flatten_stride)
        flatten_coeff_preds = torch.cat(flatten_coeff_preds, dim=1)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(len(featmap_sizes))]

        results_list = []
        for (bboxes, scores, objectness, coeffs, mask_proto,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, flatten_coeff_preds,
                              proto_preds, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            batch_input_shape = img_meta['batch_input_shape']
            input_shape_h, input_shape_w = batch_input_shape
            if 'pad_param' in img_meta:
                pad_param = img_meta['pad_param']
                input_shape_withoutpad = (input_shape_h - pad_param[0] -
                                          pad_param[1], input_shape_w -
                                          pad_param[2] - pad_param[3])
            else:
                pad_param = None
                input_shape_withoutpad = batch_input_shape
            scale_factor = (input_shape_withoutpad[1] / ori_shape[1],
                            input_shape_withoutpad[0] / ori_shape[0])

            score_thr = cfg.get('score_thr', -1)
            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]
                coeffs = coeffs[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]
                # NOTE: Important
                coeffs *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                h, w = ori_shape[:2] if rescale else img_meta['img_shape'][:2]
                empty_results.masks = torch.zeros(
                    size=(0, h, w), dtype=torch.bool, device=bboxes.device)
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0], coeffs=coeffs))
                labels = results['labels']
                coeffs = results['coeffs']
            else:
                out = filter_scores_and_topk(
                    scores, score_thr, nms_pre, results=dict(coeffs=coeffs))
                scores, labels, keep_idxs, filtered_results = out
                coeffs = filtered_results['coeffs']

            results = InstanceData(
                scores=scores,
                labels=labels,
                bboxes=bboxes[keep_idxs],
                coeffs=coeffs)

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)

            if len(results.bboxes):
                masks = self.process_mask(mask_proto, results.coeffs,
                                          results.bboxes,
                                          (input_shape_h, input_shape_w),
                                          True)

                if rescale:
                    if pad_param is not None:
                        # bbox minus pad param
                        top_pad, _, left_pad, _ = pad_param
                        results.bboxes -= results.bboxes.new_tensor(
                            [left_pad, top_pad, left_pad, top_pad])
                        # mask crop pad param
                        top, left = int(top_pad), int(left_pad)
                        bottom, right = int(input_shape_h -
                                            top_pad), int(input_shape_w -
                                                          left_pad)
                        masks = masks[:, :, top:bottom, left:right]
                    results.bboxes /= results.bboxes.new_tensor(
                        scale_factor).repeat((1, 2))

                    fast_test = cfg.get('fast_test', False)
                    if fast_test:
                        masks = F.interpolate(
                            masks,
                            size=ori_shape,
                            mode='bilinear',
                            align_corners=False)
                        masks = masks.squeeze(0)
                        masks = masks > cfg.mask_thr_binary
                    else:
                        masks.gt_(cfg.mask_thr_binary)
                        masks = torch.as_tensor(masks, dtype=torch.uint8)
                        masks = masks[0].permute(1, 2,
                                                 0).contiguous().cpu().numpy()
                        masks = mmcv.imresize(masks,
                                              (ori_shape[1], ori_shape[0]))

                        if len(masks.shape) == 2:
                            masks = masks[:, :, None]
                        masks = torch.from_numpy(masks).permute(2, 0, 1)

                results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
                results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

                results.masks = masks.bool()
                results_list.append(results)
            else:
                h, w = ori_shape[:2] if rescale else img_meta['img_shape'][:2]
                results.masks = torch.zeros(
                    size=(0, h, w), dtype=torch.bool, device=bboxes.device)
                results_list.append(results)
        return results_list
    def process_mask(self,
                     mask_proto: Tensor,
                     mask_coeff_pred: Tensor,
                     bboxes: Tensor,
                     shape: Tuple[int, int],
                     upsample: bool = False) -> Tensor:
        """Generate mask logits results.

        Args:
            mask_proto (Tensor): Mask prototype features for a single image.
                Has shape (mask_channels, H, W).
            mask_coeff_pred (Tensor): Mask coefficients prediction for a
                single image. Has shape (num_instance, mask_channels).
            bboxes (Tensor): Tensor of the bbox. Has shape
                (num_instance, 4).
            shape (Tuple): Batch input shape of image.
            upsample (bool): Whether to upsample mask results to batch input
                shape. Defaults to False.

        Return:
            Tensor: Instance segmentation masks for each instance.
                Has shape (num_instance, H, W).
        """
        c, mh, mw = mask_proto.shape  # CHW
        masks = (
            mask_coeff_pred @ mask_proto.float().view(c, -1)).sigmoid().view(
                -1, mh, mw)[None]
        if upsample:
            masks = F.interpolate(
                masks, shape, mode='bilinear', align_corners=False)  # 1CHW
        masks = self.crop_mask(masks, bboxes)
        return masks
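
Usage sketch for ``process_mask`` (sizes are illustrative and ``head`` is
assumed to be a constructed ``YOLOv5InsHead``): each instance mask is the
sigmoid of a linear combination of the shared prototypes, optionally
upsampled to the input shape and cropped to its box.

mask_proto = torch.randn(32, 160, 160)  # prototypes for one image
coeffs = torch.randn(5, 32)  # one coefficient vector per instance
bboxes = torch.tensor([[10., 10., 200., 300.]]).repeat(5, 1)
masks = head.process_mask(mask_proto, coeffs, bboxes, (640, 640), True)
assert masks.shape == (1, 5, 640, 640)  # 1CHW mask probabilities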
    def crop_mask(self, masks: Tensor, boxes: Tensor) -> Tensor:
        """Crop masks by their bounding boxes.

        Args:
            masks (Tensor): Predicted mask results. Has shape
                (1, num_instance, H, W).
            boxes (Tensor): Tensor of the bbox. Has shape
                (num_instance, 4).

        Returns:
            (torch.Tensor): The masks cropped to their bounding boxes.
        """
        _, n, h, w = masks.shape
        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)
        r = torch.arange(
            w, device=masks.device,
            dtype=x1.dtype)[None, None, None, :]  # rows shape(1, 1, 1, w)
        c = torch.arange(
            h, device=masks.device,
            dtype=x1.dtype)[None, None, :, None]  # cols shape(1, 1, h, 1)

        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
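
A minimal broadcast demo of ``crop_mask`` with toy numbers (again assuming
a constructed ``head``): the coordinate grids are compared against every
box, zeroing all mask values outside it.

masks = torch.ones(1, 1, 4, 4)
boxes = torch.tensor([[1., 1., 3., 3.]])  # x1, y1, x2, y2
cropped = head.crop_mask(masks, boxes)
# cropped[0, 0]:
# tensor([[0., 0., 0., 0.],
#         [0., 1., 1., 0.],
#         [0., 1., 1., 0.],
#         [0., 0., 0., 0.]])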