Batch vs. Layer vs. Group Normalization

Batch Normalization

  • 在图像预处理过程中通常会对图像进行标准化处理,这样能加速网络的收敛。

$$
\text{image} \overset{\text{preprocess}}{\rightarrow} \text{Conv1} -> \text{feature map} -> \text{Conv2}
$$

  • 上述过程中,对于 Conv1 的输入就是满足某一分布的特征矩阵,但对于 Conv2 而言输入的 feature map 就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个 feature map 的数据要满足分布规律,理论上是指整个训练样本集所对应 feature map 的数据要满足分布规律)。
  • Batch Normalization 的目的就是使 feature map 满足均值为 0,方差为 1 的分布规律。
  • 让 feature map 满足某一分布规律,理论上是指整个训练样本集所对应 feature map 的数据要满足分布规律,即计算整个训练集的 feature map 然后再进行标准化处理,对于大型数据集显然不可能,所以需要使用 Batch Normalization,计算一个 Batch 数据的 feature map 然后再标准化(batch 越大越接近整个数据集的分布)。
  1. 训练时要将 traning 参数设置为 True,在验证时将 trainning 参数设置为 False。在pytorch中可通过创建模型的 model.train() 和 model.eval() 方法控制。
  2. batch size 尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。
  3. 建议将 bn 层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置 bias,即使使用了偏置 bias 求出的结果也是一样。
Batch Normalization
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
import torch.nn as nn
import torch

def bn_process(feature, mean, var):
feature_shape = feature.shape
for i in range(feature_shape[1]):
# [batch, channel, height, width]
feature_t = feature[:, i, :, :]
mean_t = feature_t.mean()
# 总体标准差
std_t1 = feature_t.std()
# 样本标准差
std_t2 = feature_t.std(ddof=1)

# bn process
# 这里记得加上eps和pytorch保持一致
feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
# update calculating mean and var
mean[i] = mean[i] * 0.9 + mean_t * 0.1
var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
print(feature)


# 随机生成一个batch为2,channel为4,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 4, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())

# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

Layer Normalization

  • Layer Normalization 针对 NLP,例如 RNN,不使用 Batch Normalization 的原因是:在 RNN 这类时序网络中,时序的长度不是一个定值,比如每句话的长短不同,很难使用 BN,所以需要使用 LN。(但 ViT 还是会涉及到 LN)
  • Layer Normalization 与 Batch Normalization 的区别在于:BN 是对于一个 batch 数据的每个 channel 进行 Norm,但 LN 是对单个数据的指定维度进行 Norm 处理,与 batch 无关
  • LN 指定要 Norm 的维度必须从最后一维开始。
Layer Normalization
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn

def layer_norm_process(feature: torch.Tensor, beta=0., gamma=1., eps=1e-5):
var_mean = torch.var_mean(feature, dim=[1, 2], unbiased=False)
# 均值
mean = var_mean[1]
# 方差
var = var_mean[0]

# layer norm process
feature = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)
feature = feature * gamma + beta

return feature

def main():
t = torch.rand(4, 2, 3)
print(t)
# 仅在最后一个维度上做norm处理
norm = nn.LayerNorm(normalized_shape=t.shape[-1], eps=1e-5)
# 官方layer norm处理
t1 = norm(t)
# 自己实现的layer norm处理
t2 = layer_norm_process(t, eps=1e-5)
print("t1:\n", t1)
print("t2:\n", t2)

if __name__ == '__main__':
main()

Group Normalization

  • 最常用的 BN 有一个缺点,Batch Size 通常较大,当 batch size 小于 16 后 error 明显升高,对于大型网络或 GPU 显存不够的情况下,可以使用 Group Normalization。
  • batch size 的大小对 GN 并没有影响,当 batch size 设置较小时,可以采用 GN。
  • 对于 GN,假设 num_groups = 2(原论文默认为 32),假设某层的输出得到 x,根据 num_groups 沿 channel 方向均分成 num_groups 份,然后对每一份求均值和方差
Group Normalization
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn

def group_norm(x: torch.Tensor,
num_groups: int,
num_channels: int,
eps: float = 1e-5,
gamma: float = 1.0,
beta: float = 0.):
assert divmod(num_channels, num_groups)[1] == 0
channels_per_group = num_channels // num_groups

new_tensor = []
for t in x.split(channels_per_group, dim=1):
var_mean = torch.var_mean(t, dim=[1, 2, 3], unbiased=False)
var = var_mean[0]
mean = var_mean[1]
t = (t - mean[:, None, None, None]) / torch.sqrt(var[:, None, None, None] + eps)
t = t * gamma + beta
new_tensor.append(t)

new_tensor = torch.cat(new_tensor, dim=1)
return new_tensor

def main():
num_groups = 2
num_channels = 4
eps = 1e-5

img = torch.rand(2, num_channels, 2, 2)
print(img)

gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
r1 = gn(img)
print(r1)

r2 = group_norm(img, num_groups, num_channels, eps)
print(r2)

if __name__ == '__main__':
main()

References

本文作者:jujimeizuo
本文地址https://blog.jujimeizuo.cn/2025/03/17/BN-LN-GN/
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0 协议。转载请注明出处!