1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
import torch
import torch.nn as nn
def group_norm(
    x: torch.Tensor,
    num_groups: int,
    num_channels: int,
    eps: float = 1e-5,
    gamma: "float | torch.Tensor" = 1.0,
    beta: "float | torch.Tensor" = 0.,
) -> torch.Tensor:
    """Apply group normalization to ``x``, mirroring ``torch.nn.GroupNorm``.

    The channel axis (dim 1) is split into ``num_groups`` contiguous groups.
    For every sample, each group is normalized to zero mean and unit variance
    using the biased (population) variance. The result is then scaled by
    ``gamma`` and shifted by ``beta``.

    Args:
        x: Input of shape ``(N, C, *)`` with ``C == num_channels``. Any number
            of trailing dimensions (including none) is accepted.
        num_groups: Number of channel groups. Must be positive and evenly
            divide ``num_channels``.
        num_channels: Expected size of ``x``'s channel dimension.
        eps: Added to the variance for numerical stability.
        gamma: Scale applied after normalization. A Python number or 0-d
            tensor is applied uniformly. A 1-D tensor of length
            ``num_channels`` is applied per channel, like the affine weight
            of ``nn.GroupNorm``.
        beta: Shift applied after scaling. It accepts the same forms as
            ``gamma``.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``num_groups`` is not positive, if ``num_channels`` is
            not divisible by ``num_groups``, or if ``x`` is not of shape
            ``(N, num_channels, *)``.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if num_groups <= 0:
        raise ValueError(f"num_groups must be positive, got {num_groups}")
    if num_channels % num_groups != 0:
        raise ValueError(
            f"num_channels ({num_channels}) must be divisible by "
            f"num_groups ({num_groups})"
        )
    if x.dim() < 2 or x.shape[1] != num_channels:
        raise ValueError(
            f"expected input of shape (N, {num_channels}, *), "
            f"got {tuple(x.shape)}"
        )

    channels_per_group = num_channels // num_groups
    # View as (N, G, C/G, *). One reduction over every axis after the group
    # axis then yields per-(sample, group) statistics for any number of
    # trailing dims. The sizes are spelled out (no -1) so an empty batch
    # still reshapes.
    grouped = x.reshape(x.shape[0], num_groups, channels_per_group, *x.shape[2:])
    reduce_dims = tuple(range(2, grouped.dim()))
    var, mean = torch.var_mean(grouped, dim=reduce_dims, unbiased=False, keepdim=True)
    normalized = ((grouped - mean) / torch.sqrt(var + eps)).reshape(x.shape)

    # Reshape 1-D per-channel affine parameters to (C, 1, ..., 1) so they
    # broadcast along the channel axis. Scalars broadcast as-is.
    trailing_ones = (1,) * (x.dim() - 2)
    if isinstance(gamma, torch.Tensor) and gamma.dim() == 1:
        gamma = gamma.reshape(-1, *trailing_ones)
    if isinstance(beta, torch.Tensor) and beta.dim() == 1:
        beta = beta.reshape(-1, *trailing_ones)
    return normalized * gamma + beta
def main():
    """Compare the manual ``group_norm`` with ``nn.GroupNorm`` on random data.

    Prints the input, both outputs, and whether the two outputs agree within
    a small tolerance.
    """
    num_groups = 2
    num_channels = 4
    eps = 1e-5

    img = torch.rand(2, num_channels, 2, 2)
    print(img)

    # Freshly constructed nn.GroupNorm has weight=1 and bias=0, which matches
    # group_norm's default gamma=1.0 / beta=0.
    gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
    r1 = gn(img)
    print(r1)

    r2 = group_norm(img, num_groups, num_channels, eps)
    print(r2)

    # Eyeballing two printed tensors is error-prone; report agreement explicitly.
    print('allclose:', torch.allclose(r1, r2, atol=1e-5))
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|