批量规范化(batch normalization)

  • (Ioffe and Szegedy, 2015),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。
  • 加速深层网络的收敛速度,能够训练更深的网络
  • 从数据处理的灵感来的,使用标准化规范数据(均值为 0,方差为 1),进行正则化
    • 统一参数的数量级
      • 不同层的输出量级不一样,学习率不能适应所有网络层
    • 避免过拟合
  • 方式
    • 每个 batch 中计算均值和标准差,使用 $\frac{x-\mu}{\sigma}$
      • batch 要足够大才能有效和稳定

公式:
$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}\mathcal{B}}{\hat{\boldsymbol{\sigma}}\mathcal{B}} + \boldsymbol{\beta}.$

  • 拉伸参数(scale)$\boldsymbol{\gamma}$ 和偏移参数(shift)$\boldsymbol{\beta}$
    • 是学习的参数,否则这样均值和方差都是固定的

计算:

$$ \hat{\boldsymbol{\sigma}}\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned}\end{split}$$

  • $\epsilon$ 避免除 0

训练预测的区别:

  • 训练:只能知道小批量的均值、方差

    • 更新移动的均值和方差,预测时用
    • $\mu_t = m * \mu_{t-1} + (1-m) * \mu, \sigma_t = m * \sigma_{t-1} + (1-m) * \sigma$
      • m 是动量,训练批次每次累积均值的偏移
  • 预测:预测时可以知道所有数据的均值、方差

    • 所以预测时要提供均值和方差
  • 全连接层(nn.BatchNorm2d):$\mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ).$

  • 卷积的算法会有所不同(和通道相关,nn.BatchNorm4d):略

    • 计算各通道的均值和方差

$$