PyTorch Source Code Reading (3): BatchNorm Module

Module

nn.BNReLU2d

Source:

class BNReLU2d(nnq.BatchNorm2d):
    r"""
    A BNReLU2d module is a fused module of BatchNorm2d and ReLU

    We adopt the same interface as :class:`torch.nn.quantized.BatchNorm2d`.

    Attributes:
        Same as torch.nn.quantized.BatchNorm2d

    """
    _FLOAT_MODULE = torch.nn.intrinsic.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BNReLU2d, self).__init__(num_features, eps=eps, momentum=momentum)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return torch.ops.quantized.batch_norm2d_relu(
            input, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBNReLU2d'

    @classmethod
    def from_float(cls, mod):
        # TODO: Add qat support for BNReLU2d
        return super(BNReLU2d, cls).from_float(mod)
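A minimal usage sketch of this fused module (assuming a PyTorch version where the class is exported as torch.nn.intrinsic.quantized.BNReLU2d, and that the scale/zero_point inherited from the quantized BatchNorm2d default to 1.0 and 0):

import torch
import torch.nn.intrinsic.quantized as nniq

m = nniq.BNReLU2d(3)  # 3 channels; scale/zero_point assumed to default to 1.0/0
# The fused op expects a quantized 4D tensor, so quantize a float input first
x = torch.randn(1, 3, 4, 4)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
out = m(xq)           # batch_norm2d + relu executed as one quantized kernel
print(out.dtype)      # torch.quint8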

nn.BatchNorm2d

class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

As the source above shows, the torch.nn.BatchNorm2d class inherits essentially all of its implementation from the internal _BatchNorm class, and the latter's actual computation is implemented in C++ and CUDA rather than Python, so I will not dig into the low-level source here.
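As a quick sanity check (a minimal sketch using only the public API), _check_input_dim is what raises the error for any input that is not 4D:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
out = bn(torch.randn(2, 3, 8, 8))   # OK: 4D input of shape (N, C, H, W)
try:
    bn(torch.randn(3, 8, 8))        # 3D input is rejected by _check_input_dim
except ValueError as e:
    print(e)                        # expected 4D input (got 3D input)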

Now let's look at what this class does:

The BN layer was proposed in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. This class takes a mini-batch of 2D inputs with channels, i.e., the input has shape (N, C, H, W). It applies the following transformation to the input:
$$ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta $$
The mean and standard deviation are computed per dimension (channel) over the mini-batch; $\gamma$ and $\beta$ are learnable parameter vectors over the channel dimension, with $\gamma$ initialized to 1 and $\beta$ to 0 by default.
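To make the formula concrete, here is a minimal sketch that reproduces BatchNorm2d's training-mode output by hand; note that PyTorch normalizes with the biased batch variance (unbiased=False):

import torch
import torch.nn as nn

x = torch.randn(8, 3, 5, 5)                      # (N, C, H, W)
bn = nn.BatchNorm2d(3)
bn.train()                                       # training mode: use batch statistics
y = bn(x)

# E[x] and Var[x] per channel, computed over the (N, H, W) dimensions
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
# gamma (bn.weight) defaults to 1 and beta (bn.bias) to 0
y_manual = y_manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(y, y_manual, atol=1e-6))    # True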
