Skip to content
This repository was archived by the owner on Jan 26, 2022. It is now read-only.

Commit 2a5facd

Browse files
committed
Fix group norm layer
1 parent 32cfbc6 commit 2a5facd

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

lib/nn/functional.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33

44
def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Apply Group Normalization over the channels of `x`.

    Splits the C channels into `num_groups` groups, normalizes each group
    by its own mean and (biased) variance, then optionally applies a
    per-channel affine transform.

    Args:
        x: input tensor of shape (N, C, *) — any number of trailing dims.
        num_groups: number of groups; C must be divisible by it.
        weight: optional per-channel scale (gamma) of shape (C,).
        bias: optional per-channel shift (beta) of shape (C,).
        eps: small constant added to the variance for numerical stability.

    Returns:
        Tensor of the same shape as `x`.

    Raises:
        AssertionError: if C is not divisible by `num_groups`.
    """
    input_shape = x.shape
    ndim = len(input_shape)
    N, C = input_shape[:2]
    G = num_groups
    assert C % G == 0, "input channel dimension must be divisible by number of groups"
    # Flatten each group's channels and spatial dims so stats are per (N, G).
    x = x.view(N, G, -1)
    mean = x.mean(-1, keepdim=True)
    # Biased estimator (unbiased=False) matches the GroupNorm definition and
    # torch.nn.functional.group_norm; torch's default (unbiased=True) does not.
    var = x.var(-1, unbiased=False, keepdim=True)
    x = (x - mean) / (var + eps).sqrt()
    x = x.view(input_shape)
    # Shape (1, C, 1, ..., 1) so the affine params broadcast over all
    # trailing dimensions, whatever the input rank.
    view_shape = (1, -1) + (1,) * (ndim - 2)
    if weight is not None:  # affine=True
        return x * weight.view(view_shape) + bias.view(view_shape)
    return x

lib/nn/modules/normalization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class GroupNorm(nn.Module):
1010
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
11-
super(GroupNorm, self).__init__()
11+
super().__init__()
1212
self.num_groups = num_groups
1313
self.num_channels = num_channels
1414
self.eps = eps
@@ -28,7 +28,7 @@ def reset_parameters(self):
2828

2929
def forward(self, x):
    """Group-normalize `x` using this module's configuration.

    Delegates to the functional implementation with the module's group
    count, affine parameters (None when affine=False), and epsilon.
    """
    return myF.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
3333

3434
def extra_repr(self):

0 commit comments

Comments
 (0)