
Commit 6805e39

Fix LayerNorm fp16 precision
1 parent 8e2c82d commit 6805e39
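
In short: the standalone aten.layer_norm converter registrations are dropped in favor of the aten.native_layer_norm path, default weight/bias handling moves into the impl (both are cast to the input's dtype before the TensorRT normalization layer is built), and the explicit compute_precision override is removed so fp16 inputs stay in fp16. A minimal sketch of the dtype contract the converter now relies on (illustrative only, not part of this commit; assumes a recent PyTorch build):

import torch

# native_layer_norm preserves the input dtype; the converter mirrors this by
# casting weight/bias to input.dtype instead of pinning a compute precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
x = torch.randn(2, 4, 6, device=device, dtype=dtype)
out, mean, rstd = torch.ops.aten.native_layer_norm.default(x, [6], None, None, 1e-05)
print(out.dtype)  # same as x.dtype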

File tree: 3 files changed, +48 -108 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+3 -9)

@@ -134,10 +134,6 @@ def aten_ops_batch_norm_legit_no_training(
     capability_validator=one_user_validator,
     supports_dynamic_shapes=True,
 )
-@dynamo_tensorrt_converter(
-    torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
-)
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -157,11 +153,9 @@ def aten_ops_layer_norm(
         name,
         input=args[0],
         normalized_shape=args[1],
-        weight=args_bounds_check(args, 2, 1.0),
-        bias=args_bounds_check(args, 3, 0.0),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enable=args_bounds_check(args, 5, True),
-        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
+        weight=args_bounds_check(args, 2),
+        bias=args_bounds_check(args, 3),
+        eps=args[4],
     )
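Note on the converter change above: torch.compile/torch.export decompose aten.layer_norm into aten.native_layer_norm, so the separate layer_norm registrations were redundant; weight/bias defaults now live in the impl, and eps is read directly from args[4] since native_layer_norm always supplies it. A quick way to confirm the decomposition (an illustrative sketch, not from the repo):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, [6])

# Lowering to the core ATen opset turns layer_norm into native_layer_norm,
# which is the op the remaining converter handles.
ep = torch.export.export(M(), (torch.randn(2, 4, 6),)).run_decompositions()
print(ep.graph)  # look for torch.ops.aten.native_layer_norm.default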

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py (+16 -23)

@@ -159,49 +159,42 @@ def layer_norm(
     name: str,
     input: TRTTensor,
     normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     eps: float,
-    cudnn_enable: bool,
-    return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    weight = get_trt_tensor(
+        ctx, weight if weight is not None else 1.0, f"{name}_weight"
+    )
+    bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")
+
     # Cast weight and bias to have same dtype as input
     weight = cast_trt_tensor(
         ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
     )
     bias = cast_trt_tensor(
         ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
     )
+
     if tuple(input.shape) != tuple(weight.shape):
         weight = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
         )
+
     if tuple(input.shape) != tuple(bias.shape):
         bias = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
         )
-    strongly_typed_network = False
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        weight = cast_trt_tensor(ctx, weight, input.dtype, name)
-        bias = cast_trt_tensor(ctx, bias, input.dtype, name)
-        strongly_typed_network = True
-
-    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
-    layer_norm.epsilon = eps
-    # compute_precision ignored for strongly typed network.
-    if not strongly_typed_network:
-        layer_norm.compute_precision = input.dtype
-    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
 
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return layer_norm.get_output(0), None, None
+    layer = ctx.net.add_normalization(input, weight, bias, axes)
+    layer.epsilon = eps
+    set_layer_name(layer, target, name, source_ir)
 
-    return layer_norm.get_output(0)
+    # return fake mean and rstd for now
+    return layer.get_output(0), None, None
 
 
 def native_group_norm(
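The core of the fp16 fix is in this file: weight and bias default to 1.0/0.0 when absent, are cast to the input's dtype, and are expanded to the input shape before ctx.net.add_normalization; the old compute_precision override and the strongly-typed special case it required are gone, so the normalization math follows the input dtype. As reference semantics, a rough PyTorch equivalent of what the converted layer computes (a hedged sketch; layer_norm_ref is an illustrative name, not repo code):

import torch

def layer_norm_ref(x, normalized_shape, weight=None, bias=None, eps=1e-05):
    # Normalize over the trailing `normalized_shape` dims, matching the
    # converter's dims/axes computation.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    # Default weight/bias to 1.0/0.0 and match the input dtype, mirroring
    # get_trt_tensor(...) + cast_trt_tensor(...) in the diff above.
    w = torch.ones(normalized_shape, dtype=x.dtype) if weight is None else weight.to(x.dtype)
    b = torch.zeros(normalized_shape, dtype=x.dtype) if bias is None else bias.to(x.dtype)
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * w + b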

tests/py/dynamo/conversion/test_layer_norm_aten.py (+29 -76)

@@ -6,78 +6,51 @@
 from .harness import DispatchTestCase
 
 
-class TestLayerNormConverter(DispatchTestCase):
+class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
-            ((5, 3, 2, 4), [2, 4]),
-            ((5, 3, 2, 4), [3, 2, 4]),
-            ((5, 3, 2, 4), [5, 3, 2, 4]),
+            ((2, 4, 6), [6]),
+            ((2, 4, 6), [4, 6]),
+            ((2, 4, 6), [2, 4, 6]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_1d(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
-                )
+                return torch.ops.aten.native_layer_norm.default(
+                    x, normalized_shape, None, None, 1e-05
+                )[0]
 
         inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
-
+        self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)
 
-class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
+            ((5, 3, 2, 4), [4]),
             ((5, 3, 2, 4), [2, 4]),
             ((5, 3, 2, 4), [3, 2, 4]),
             ((5, 3, 2, 4), [5, 3, 2, 4]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_2d(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
+                    x, normalized_shape, weight, bias, 1e-05
                 )[0]
 
-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)
 
     def test_layernorm_with_dynamic_shape(self):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
+                    x, [3, 224, 224], weight, bias, 1e-05
                 )[0]
 
         input_specs = [
@@ -87,22 +60,19 @@ def forward(self, x):
                 opt_shape=(5, 3, 224, 224),
                 max_shape=(10, 3, 224, 224),
             ),
+            Input(dtype=torch.float32, shape=(3, 224, 224)),
+            Input(dtype=torch.float32, shape=(3, 224, 224)),
         ]
 
         self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
+            LayerNorm(), input_specs, use_dynamo_tracer=True
         )
 
     def test_layernorm_with_dynamic_shape_1(self):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    torch.tensor([3]),
-                    torch.ones((3)),
-                    torch.zeros((3)),
-                    1e-05,
+                    x, [3], weight, bias, 1e-05
                 )[0]
 
         input_specs = [
@@ -112,29 +82,12 @@ def forward(self, x):
                 opt_shape=(3, 3, 3),
                 max_shape=(4, 5, 3),
             ),
+            Input(dtype=torch.float32, shape=(3,)),
+            Input(dtype=torch.float32, shape=(3,)),
         ]
 
         self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
-        )
-
-    @parameterized.expand([((5, 3, 2, 4), [2, 4])])
-    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    normalized_shape,
-                    None,
-                    None,
-                    eps,
-                )[0]
-
-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
+            LayerNorm(), input_specs, use_dynamo_tracer=True
         )
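Test-side summary: weight and bias now enter as graph inputs (exercising the TRTTensor branch of the converter signature), every case traces through native_layer_norm with use_dynamo_tracer=True, and the old test_layer_norm_without_Scaling case is subsumed by test_layer_norm_1d passing None weight/bias. One way to run just this file (assuming the repo's standard pytest setup):

python -m pytest tests/py/dynamo/conversion/test_layer_norm_aten.py -v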
