Shortcuts

mmyolo.models.backbones.base_backbone 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseBackbone(BaseModule, metaclass=ABCMeta):
    """BaseBackbone backbone used in YOLO series.

    .. code:: text

     Backbone model structure diagram
     +-----------+
     |   input   |
     +-----------+
           v
     +-----------+
     |   stem    |
     |   layer   |
     +-----------+
           v
     +-----------+
     |   stage   |
     |  layer 1  |
     +-----------+
           v
     +-----------+
     |   stage   |
     |  layer 2  |
     +-----------+
           v
         ......
           v
     +-----------+
     |   stage   |
     |  layer n  |
     +-----------+
     In P5 model, n=4
     In P6 model, n=5

    Subclasses must implement :meth:`build_stem_layer` and
    :meth:`build_stage_layer`; both are called from ``__init__`` after the
    configuration attributes have been stored on ``self``.

    Args:
        arch_setting (list): Architecture of BaseBackbone. One entry per
            stage; its length defines ``num_stages``.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels.
            Defaults to 3.
        out_indices (Sequence[int]): Output from which stages. Index 0
            refers to the stem and index ``i`` to ``stage{i}``.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters, 0 freezes the
            stem only, ``k`` freezes the stem and stages 1..k.
            Defaults to -1.
        plugins (dict or list[dict], optional): Plugins for stages. A
            single dict is treated as a one-element list. Each dict
            contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin,
              length should be same as 'num_stages'.

            Defaults to None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to None.
        act_cfg (dict): Config dict for activation layer.
            Defaults to None.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Defaults to None.

    Raises:
        AssertionError: If ``out_indices`` contains an index outside
            ``range(len(arch_setting) + 1)``.
        ValueError: If ``frozen_stages`` is outside
            ``range(-1, len(arch_setting) + 1)``.
    """

    def __init__(self,
                 arch_setting: list,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Sequence[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 plugins: Union[dict, List[dict]] = None,
                 norm_cfg: ConfigType = None,
                 act_cfg: ConfigType = None,
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        self.num_stages = len(arch_setting)
        self.arch_setting = arch_setting

        # Valid layer indices are 0 (stem) .. num_stages (last stage).
        # Kept as an ``assert`` so the exception type seen by existing
        # callers (AssertionError) does not change.
        assert set(out_indices).issubset(range(len(arch_setting) + 1))

        if frozen_stages not in range(-1, len(arch_setting) + 1):
            raise ValueError('"frozen_stages" must be in range(-1, '
                             'len(arch_setting) + 1). But received '
                             f'{frozen_stages}')

        # Store the configuration before building any layer: the abstract
        # build hooks below are implemented by subclasses and are given no
        # arguments beyond the stage setting, so they rely on ``self``.
        self.input_channels = input_channels
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.widen_factor = widen_factor
        self.deepen_factor = deepen_factor
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.plugins = plugins

        self.stem = self.build_stem_layer()
        # Ordered attribute names of the stem and every stage; ``forward``
        # and ``_freeze_stages`` index into this list.
        self.layers = ['stem']

        for idx, setting in enumerate(arch_setting):
            stage = []
            stage += self.build_stage_layer(idx, setting)
            if plugins is not None:
                # Plugins are appended after the regular stage modules.
                stage += self.make_stage_plugins(plugins, idx, setting)
            self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    @abstractmethod
    def build_stem_layer(self):
        """Build a stem layer."""
        pass

    @abstractmethod
    def build_stage_layer(self, stage_idx: int, setting: list):
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        pass

    def make_stage_plugins(self, plugins, stage_idx, setting):
        """Make plugins for backbone ``stage_idx`` th stage.

        Currently we support to insert ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
        into the backbone.

        An example of plugins format could be:

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True)),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True)),
            ... ]
            >>> model = YOLOv5CSPDarknet()
            >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
            >>> assert len(stage_plugins) == 1

        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> yyy

        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> xxx -> yyy

        Args:
            plugins (dict or list[dict]): Plugin cfgs to build. A single
                dict is treated as a one-element list. If ``stages`` is
                missing from a plugin dict, that plugin is applied to all
                stages.
            stage_idx (int): Index of stage to build.
            setting (list): The architecture setting of a stage layer.
                ``setting[1]`` is scaled by ``widen_factor`` and passed to
                every plugin as ``in_channels``.

        Returns:
            list[nn.Module]: Plugins for current stage
        """
        # The constructor advertises ``Union[dict, List[dict]]``; iterating
        # a bare dict would yield its string keys and fail on ``.copy()``,
        # so wrap a single plugin config into a list first.
        if isinstance(plugins, dict):
            plugins = [plugins]

        # TODO: It is not general enough to support any channel and needs
        # to be refactored
        in_channels = int(setting[1] * self.widen_factor)
        plugin_layers = []
        for plugin in plugins:
            # Shallow copy so popping 'stages' does not mutate the caller's
            # config, which is reused for every stage.
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            if stages is None or stages[stage_idx]:
                # Only the built layer is used; the returned name is dropped.
                _, layer = build_plugin_layer(
                    plugin['cfg'], in_channels=in_channels)
                plugin_layers.append(layer)
        return plugin_layers

    def _freeze_stages(self):
        """Freeze the parameters of the specified stage so that they are no
        longer updated.

        With ``frozen_stages = k >= 0`` the first ``k + 1`` entries of
        ``self.layers`` (the stem plus stages 1..k) are put in eval mode and
        their parameters have ``requires_grad`` set to False.
        """
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True):
        """Convert the model into training mode while keep normalization layer
        frozen.

        Args:
            mode (bool): Whether to set training mode (True) or evaluation
                mode (False). Defaults to True.
        """
        super().train(mode)
        # ``super().train`` resets every submodule's mode, so the frozen
        # stages have to be switched back to eval afterwards.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward batch_inputs from the data_preprocessor.

        Args:
            x (torch.Tensor): Input passed through the stem and then each
                stage in order.

        Returns:
            tuple: Outputs of the layers whose position in ``self.layers``
            is listed in ``out_indices`` (0 is the stem), in layer order.
        """
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
Read the Docs v: latest
Versions
latest
stable
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.