基于mmdetection源码:Faster R-CNN算法解读

Faster R-CNN算法源码解读

前言

OpenMMLab学习资料:OpenMMLab 模块化设计背后的功臣

对于mmdetection的学习,可以直接阅读官方的文档,现在官方提供了中文文档,非常详细,点击mmdetection中文文档阅读。

关于基于mmdetection框架实现Faster R-CNN算法,可以先阅读官方在知乎上发布的文章,轻松掌握 MMDetection 中常用算法(二):Faster R-CNN|Mask R-CNN,本篇博客是对其进行从源码实现层面进行补充解读。

mmdetection代码框架

通过下面命令拉取最新的代码

1
git clone git@github.com:open-mmlab/mmdetection.git

代码仓库主分支为master分支,目前计算机视觉领域中目标检测算法正在快速发展,mmdetection也不断地实现新的算法,最新实现的算法查看CHANGELOG

如果觉得master代码更新比较频繁,影响自己基于mmdetection的算法实现,也可以基于最新发布的版本创建自己的分支,如下命令:

1
git checkout -b newbranch v2.25.0 # 截止2022.06.19发布的最新稳定版

目前还是推荐直接使用master代码,并且要经常 git pull 拉取最新的代码,这样最新的一些组件也会更新,方便基于最新的组件搭建自己的网络算法。

之前阅读的YOLO v4论文中,作者将目标检测抽象为下面几个部分,mmdetection的代码也是这样抽象的:

现代基于卷积神经网络目标检测框架

下面代码分析基于mmdetection的源码当前最近commit: ca11860f4f3c3ca2ce8340e2686eeaec05b29111 ,时间 2022.06.20。

下面是mmdetection代码框架的结构:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
.
├── CITATION.cff
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_zh-CN.md
├── andy_README.md
├── configs # 存放所有的配置文件,可以自己写配置文件,也可以继承自_base_下的四种配置文件
├── demo
├── docker
├── docs
├── mmdet # mmdeteetion的核心代码部分,其他工具都依赖于该部分的代码
├── model-index.yml
├── pytest.ini
├── requirements
├── requirements.txt
├── resources
├── setup.cfg
├── setup.py
├── tests # 集成测试相关代码
└── tools # 提供训练、测试等工具代码

以下是configs目录下的配置文件说明:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
configs
├── _base_
│   ├── datasets # 数据集加载,不同的数据集格式获取数据代码。
│   ├── default_runtime.py # 默认的运行时配置文件,主要配置包括权重保存频率、日志频率,日志等级等信息。
│   ├── models # 不同的目标检测模型配置文件
│   └── schedules # 各种训练策略配置文件
├── albu_example  # ...其他目录文件:1)继承自上面的四种文件,并做一些针对性修改;2)可以自己写配置文件,配置所有参数。
│   ├── README.md
│   └── mask_rcnn_r50_fpn_albu_1x_coco.py
├── atss
│   ├── README.md
│   ├── atss_r101_fpn_1x_coco.py
│   ├── atss_r50_fpn_1x_coco.py
│   └── metafile.yml
├── autoassign

以下是mmdet(核心代码)目录下的代码说明:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
mmdet
├── __init__.py # 判断配置的mmcv是否符合要求
├── apis # 训练和测试相关依赖的函数,有随机种子生成、训练、单GPU测试、多GPU测试等
│   ├── __init__.py
│   ├── inference.py
│   ├── test.py
│   └── train.py
├── core # 内核代码,包括锚框生成、边界框计算、结果评估、数据结构、掩码生成、可视化、钩子函数等等核心代码
│   ├── __init__.py
│   ├── anchor
│   ├── bbox
│   ├── data_structures
│   ├── evaluation
│   ├── export
│   ├── hook
│   ├── mask
│   ├── optimizers
│   ├── post_processing
│   ├── utils
│   └── visualization
├── datasets # 数据加载器的具体实现,对应configs/datasets
│   ├── __init__.py
│   ├── api_wrappers
│   ├── builder.py
│   ├── cityscapes.py
│   ├── coco.py
│   ├── coco_panoptic.py
│   ├── custom.py
│   ├── dataset_wrappers.py
│   ├── deepfashion.py
│   ├── lvis.py
│   ├── openimages.py
│   ├── pipelines
│   ├── samplers
│   ├── utils.py
│   ├── voc.py
│   ├── wider_face.py
│   └── xml_style.py
├── models # 不同模型的具体实现,分为不同的主干、颈部、头部、损失函数等等,对应 configs/models
│   ├── __init__.py
│   ├── backbones
│   ├── builder.py
│   ├── dense_heads
│   ├── detectors
│   ├── losses
│   ├── necks
│   ├── plugins
│   ├── roi_heads
│   ├── seg_heads
│   └── utils
├── utils # 通用工具
│   ├── __init__.py
│   ├── collect_env.py
│   ├── compat_config.py
│   ├── contextmanagers.py
│   ├── logger.py
│   ├── memory.py
│   ├── misc.py
│   ├── profiling.py
│   ├── replace_cfg_vals.py
│   ├── setup_env.py
│   ├── split_batch.py
│   ├── util_distribution.py
│   ├── util_mixins.py
│   └── util_random.py
└── version.py # 记录mmdetection 的版本

以下是tools目录下的代码说明:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
tools
├── analysis_tools # 分析日志和预测效果
├── dataset_converters # 数据集转换
├── deployment # 部署工具
├── dist_test.sh
├── dist_train.sh
├── misc # 杂项,下载数据集、打印配置信息等工具
├── model_converters
├── slurm_test.sh
├── slurm_train.sh
├── test.py # 测试模型
└── train.py # 根据配置文件进行训练

Faster R-CNN 模型训练

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
data
└── VOCdevkit
    ├── VOC2007
    │   ├── Annotations
    │   ├── ImageSets
    │   ├── JPEGImages
    │   ├── SegmentationClass
    │   └── SegmentationObject
    └── VOC2012
        ├── Annotations
        ├── ImageSets
        ├── JPEGImages
        ├── SegmentationClass
        └── SegmentationObject

根据官方文档,配置相关环境,下载VOC数据集(这里选择相对较小的VOC数据进行训练),并且组织好数据集目录结构,如上所示,然后就可以使用下面命令进行训练:

1
python tools/train.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py

笔者训练得到的结果是:mAP 78.8,该结果和官方给的代码结果差将近两个点,可能是训练时候的参数设置不一样导致的。

Faster R-CNN 模型测试

使用下面命令进行测试:

1
python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py checkpoints/faster_rcnn_r50_fpn_1x_voc0712_20220320_192712-54bef0f3.pth --eval mAP

这里的checkpoints目录下的模型文件是我下载mmdetection训练好的模型文件,可以替换成自己训练好的文件,进行测试。

Faster R-CNN 配置文件解读

上面训练和测试使用的模型配置文件都是configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py,内容如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py', '../_base_/datasets/voc0712.py',
    '../_base_/default_runtime.py'
]
model = dict(roi_head=dict(bbox_head=dict(num_classes=20)))  # 修改类别个数
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)  # 不进行梯度截断
# learning policy
# actual epoch = 3 * 3 = 9
lr_config = dict(policy='step', step=[3])
# runtime settings
runner = dict(
    type='EpochBasedRunner', max_epochs=4)  # actual epoch = 4 * 3 = 12

可以看到上面的配置文件,继承自三个配置文件,下面解读configs/_base_/models/faster_rcnn_r50_fpn.py模型配置文件,内容如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# model settings
model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='ResNet',  # 骨架网络类名
        depth=50,  # 表示使用 ResNet50
        num_stages=4,  # ResNet 系列包括 stem + 4个 stage 输出
        # 表示本模块输出的特征图索引,(0, 1, 2, 3),表示4个 stage 输出都需要,
        # 其 stride 为 (4,8,16,32),channel 为 (256, 512, 1024, 2048)
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,  # 表示固定 stem 加上第一个 stage 的权重,不进行训练
        norm_cfg=dict(type='BN', requires_grad=True),   # BN 层
        norm_eval=True,  # backbone 所有的 BN 层的均值和方差都直接采用全局预训练值,不进行更新
        style='pytorch',  # 默认采用 pytorch 模式
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],  # ResNet 模块输出的4个尺度特征图通道数
        out_channels=256,  # FPN 输出的每个尺度输出特征图通道
        num_outs=5),  # FPN 输出特征图个数,将最高层特征图再进行一个pooling操作,产生更高层特征图
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,  # FPN 层输出特征图通道数
        feat_channels=256,  # 中间特征图通道数
        anchor_generator=dict(
            type='AnchorGenerator',
            # 相当于 octave_base_scale,表示每个特征图的 base scales ???
            # FPN 已经有多尺度操作,所以这里尺度只有一个8x8
            scales=[8],
            # 然后通过控制不同的长宽比去产生不同比列的提议框
            ratios=[0.5, 1.0, 2.0],  # 每个特征图有 3 个高宽比例
            strides=[4, 8, 16, 32, 64]),  # 特征图对应的 stride,必须和特征图 stride 一致,不可以随意更改 ???

        bbox_coder=dict(  # 常用! 对边界框回归的目标值进行一个编码
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(  # roi_head会基于rpn_head产生的提议框,以及原图的特征图进行预测
        type='StandardRoIHead',
        bbox_roi_extractor=dict(  # 把提议框区域内的特征图从全图特征图中裁剪下来
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(  # 使用上面裁剪下来的特征送入
            type='Shared2FCBBoxHead',  # 2 个共享 FC 模块
            in_channels=256,  # 输入通道数,相等于 FPN 输出通道
            fc_out_channels=1024,  # 中间 FC 层节点个数
            roi_feat_size=7,  # RoIAlign 或 RoIPool 输出的特征图大小
            num_classes=80,   # 类别个数
            bbox_coder=dict(  # bbox 编解码策略???,除了参数外和 RPN 相同
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            # 影响 bbox 分支的通道数,True 表示 4 通道输出,False 表示 4×num_classes 通道输出
            reg_class_agnostic=False,
            loss_cls=dict(
                # 80类分类问题,没有使用sigmoid激活函数
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',  # 最大 IoU 原则分配器
                pos_iou_thr=0.7,  # 正样本阈值
                neg_iou_thr=0.3,  # 负样本阈值
                min_pos_iou=0.3,  # 正样本阈值下限 ???
                match_low_quality=True,
                ignore_iof_thr=-1),  # 忽略 bboxes 的阈值,-1 表示不忽略
            sampler=dict(
                type='RandomSampler',  # 随机采样
                num=256,  # 产生1000个框,只采样256个进行训练,采样后每张图片的训练样本总数,不包括忽略样本
                pos_fraction=0.5,  # 正样本比例
                neg_pos_ub=-1,  # 正负样本比例,用于确定负样本采样个数上界
                add_gt_as_proposals=False),  # 是否加入 gt 作为 proposals 以增加高质量正样本数
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,  # nms之前
            max_per_img=1000,  # nms之后
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))

下面对configs/_base_/datasets/voc0712.py数据集配置文件进行解读,其内容如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py', '../_base_/datasets/voc0712.py',
    '../_base_/default_runtime.py'
]
model = dict(roi_head=dict(bbox_head=dict(num_classes=20)))  # 修改类别个数
data = dict(
    samples_per_gpu=8,  # 修改GPU的batch_size,注意不能让其超过显存
    workers_per_gpu=8,  # 修改GPU的workers
)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)  # 不进行梯度截断
# learning policy
# 有误 actual epoch = 3 * 3 = 9
lr_config = dict(policy='step', step=[8, 11])
# runtime settings
runner = dict(
    type='EpochBasedRunner', max_epochs=12)  # 有误  actual epoch = 4 * 3 = 12

笔者在源码基础上做了一些修改,将samples_per_gpu修改为8workers_per_gpu修改为8lr_config修改为dict(policy='step', step=[8, 11])runner修改为dict(type='EpochBasedRunner', max_epochs=12),注意实际训练时候batch_sizesamples_per_gpu x gpus,我这里训练时候只用了一个GPU,所以batch_size为 8。

训练代码解读

首先从tools里面的train.py入手,官方源码文件train.py

该python文件中main函数其实就做了四件事,这也是训练神经网络通用的步骤:

  1. 设定和读取各种配置;
  2. 创建模型;
  3. 创建数据集;
  4. 将模型,数据集和配置传进训练函数,进行训练;

下面截取tools/train.py中main函数的代码片段进行解读,其内容如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def main():
    # 第一件事
    args = parse_args()

    cfg = Config.fromfile(args.config)

    ...  # 省略部分代码,该部分代码对训练时上下文进行设置和校验

    # 第二件事 创建模型
    # 第一个参数cfg.model
    #      模型配置里面必须要有一个种类type,包括经典的算法如Faster RCNN, MaskRCNN等
    #      其次,还包含几个部分,如backbone, neck, head
    #      backbone有深度,stage等信息,如resnet50对应着3,4,6,3四个重复stages
    #      neck一般FPN(feature pyramid network),需要指定num_outs几个输出之类的信息(之后会看到)
    #      head 就是具体到上层rpn_head, shared_head, bbox_head之类的
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.init_weights()  # 初始化模型参数

    # 第三件事
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:  #是否添加验证集
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    # 第四件事
    train_detector(
            model,
            datasets,
            cfg,
            distributed=distributed,
            validate=args.validate,
            logger=logger)

让我们继续阅读build_detector函数的代码片段,其内容如下,该代码在mmdect/models/builder.py中:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.cnn import MODELS as MMCV_MODELS  # mmdet中MODELS继承自该MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS

... # 省略部分代码

def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))

这里采用了Registry方式,即在mmdet/models/builder.py中实例化了一个Registry对象,该对象的父类是MMCV_MODELS,该类是mmcv中的一个类,其作用是用来管理多个模型,其中有一个build函数,该函数的作用是根据模型的配置信息,创建模型,并返回模型对象。

下面代码来自mmcv,版本为1.5.2。

从上面代码可以看到,mmdet中创建模型的函数build_model_from_cfg继承自mmcv/cnn/builder.py中的build_model_from_cfg函数,其内容如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules, is is either a config
            dict or a list of config dicts. If cfg is a list, a
            the built modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)

下面我们在mmcv/utils/registry.py中查看build_from_cfg代码,如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:  # type必须在配置文件里
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')  # 取出type的值
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # 从注册器里面取出type,注册过才能用
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        # 根据配置参数实例化模型,就是type定义的那个类(FasterRCNN)的实例化,返回了这个类的对象。
        # 后面会再分析具体模型类(FasterRCNN)的定义。
        return obj_cls(**args)  
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

mmcv源码查看,到此为止。

本篇博客分析Faster R-CNN,该算法在mmdet中模型注册type为FasterRCNN,我们找到该类的实现,位置为mmdet/models/detectors/faster_rcnn.py,如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
    """Implementation of `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_"""

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

那么核心功能肯定都在TwoStageDetector里了。看代码之前首先回顾一下一个经典的双阶段检测Faster R-CNN的流程。

经典双阶段Faster R-CNN流程

下面我们来分析下/mmdet/models/detectors/two_stage.py文件,如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
    """Base class for two-stage detectors.

    Two-stage detectors typically consisting of a region proposal network and a
    task-specific regression head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(TwoStageDetector, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
            self.rpn_head = build_head(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=test_cfg.rcnn)
            roi_head.pretrained = pretrained
            self.roi_head = build_head(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    @property
    def with_rpn(self):
        """bool: whether the detector has RPN"""
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    @property
    def with_roi_head(self):
        """bool: whether the detector has a RoI head"""
        return hasattr(self, 'roi_head') and self.roi_head is not None

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`
        """
        outs = ()
        # backbone
        x = self.extract_feat(img)
        # rpn
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            outs = outs + (rpn_outs, )
        proposals = torch.randn(1000, 4).to(img.device)
        # roi_head
        roi_outs = self.roi_head.forward_dummy(x, proposals)
        outs = outs + (roi_outs, )
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box

            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

            proposals : override rpn proposals with custom proposals. Use when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg,
                **kwargs)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)

        return losses
        
... # 省略部分代码

关于上面代码,我画了一个示意图,可以方便理解两阶段检测的架构。 目标检测两阶段架构

可以看到 TwoStageDetector 类的 forward_train 函数比较简单,最难的部分为两个 head ,分别为 RPNHead 和 StandardRoIHead,下面先分析 RPNHead。

该代码位于mmdet/models/dense_heads/rpn_head.py。

参考文章

入门mmdetection(壹)

comments powered by Disqus
Built with Hugo
主题 StackJimmy 设计