Commit 8148866

Fix LayerNorm fp16 precision
1 parent 8e2c82d commit 8148866

3 files changed: +28 -78 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+3 -9)

@@ -134,10 +134,6 @@ def aten_ops_batch_norm_legit_no_training(
     capability_validator=one_user_validator,
     supports_dynamic_shapes=True,
 )
-@dynamo_tensorrt_converter(
-    torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
-)
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -157,11 +153,9 @@ def aten_ops_layer_norm(
         name,
         input=args[0],
         normalized_shape=args[1],
-        weight=args_bounds_check(args, 2, 1.0),
-        bias=args_bounds_check(args, 3, 0.0),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enable=args_bounds_check(args, 5, True),
-        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
+        weight=args[2],
+        bias=args[3],
+        eps=args[4],
     )
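Note (illustrative commentary, not part of this commit): with the dedicated aten.layer_norm registrations removed, layer_norm calls presumably reach this converter only after being decomposed to aten.native_layer_norm.default, whose schema is (input, normalized_shape, weight?, bias?, eps) and which returns an (output, mean, rstd) triple. A minimal stock-PyTorch sketch to confirm that decomposition:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(4)

    def forward(self, x):
        return self.norm(x)

# Export and apply the core ATen decompositions; nn.LayerNorm / aten.layer_norm
# should be lowered to aten.native_layer_norm.default, the op this converter handles.
ep = torch.export.export(Model(), (torch.randn(2, 3, 4),)).run_decompositions()
print(ep.graph_module.code)  # expect a call to torch.ops.aten.native_layer_norm.default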

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py (+16 -23)

@@ -159,49 +159,42 @@ def layer_norm(
     name: str,
     input: TRTTensor,
     normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     eps: float,
-    cudnn_enable: bool,
-    return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    weight = get_trt_tensor(
+        ctx, weight if weight is not None else 1.0, f"{name}_weight"
+    )
+    bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")
+
     # Cast weight and bias to have same dtype as input
     weight = cast_trt_tensor(
         ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
     )
     bias = cast_trt_tensor(
         ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
     )
+
     if tuple(input.shape) != tuple(weight.shape):
         weight = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
         )
+
     if tuple(input.shape) != tuple(bias.shape):
         bias = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
         )
-    strongly_typed_network = False
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        weight = cast_trt_tensor(ctx, weight, input.dtype, name)
-        bias = cast_trt_tensor(ctx, bias, input.dtype, name)
-        strongly_typed_network = True
-
-    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
-    layer_norm.epsilon = eps
-    # compute_precision ignored for strongly typed network.
-    if not strongly_typed_network:
-        layer_norm.compute_precision = input.dtype
-    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)

-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return layer_norm.get_output(0), None, None
+    layer = ctx.net.add_normalization(input, weight, bias, axes)
+    layer.epsilon = eps
+    set_layer_name(layer, target, name, source_ir)

-    return layer_norm.get_output(0)
+    # return fake mean and rstd for now
+    return layer.get_output(0), None, None


 def native_group_norm(
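Note on the fp16 fix (illustrative commentary, not part of the diff): the removed branch set compute_precision = input.dtype on the normalization layer, which for fp16 inputs forced the mean/variance reduction itself into half precision; per the TensorRT documentation, INormalizationLayer otherwise defaults to fp32 accumulation, so dropping the override keeps the statistics accurate while inputs and outputs stay fp16. A standalone numerical sketch of why that matters (names and shapes here are made up for illustration):

import torch

def normalize(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # LayerNorm statistics computed entirely in `dtype`; result compared in fp32.
    x = x.to(dtype)
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    return ((x - mean) / torch.sqrt(var + 1e-5)).float()

x = 10 + torch.randn(8, 4096)  # non-zero mean makes half-precision statistics noisy
ref = normalize(x, torch.float32)
fp16 = normalize(x, torch.float16)
print("max abs error with fp16 accumulation:", (fp16 - ref).abs().max().item())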

tests/py/dynamo/conversion/test_layer_norm_aten.py (+9 -46)

@@ -6,61 +6,24 @@
 from .harness import DispatchTestCase


-class TestLayerNormConverter(DispatchTestCase):
-    @parameterized.expand(
-        [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
-            ((5, 3, 2, 4), [2, 4]),
-            ((5, 3, 2, 4), [3, 2, 4]),
-            ((5, 3, 2, 4), [5, 3, 2, 4]),
-        ]
-    )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
-                )
-
-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
-
-
 class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
+            ((5, 3, 2, 4), [4]),
             ((5, 3, 2, 4), [2, 4]),
             ((5, 3, 2, 4), [3, 2, 4]),
             ((5, 3, 2, 4), [5, 3, 2, 4]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
                     normalized_shape,
                     torch.randn(normalized_shape),
                     torch.randn(normalized_shape),
-                    eps,
+                    1e-05,
                 )[0]

         inputs = [torch.randn(input_shape)]
@@ -74,7 +37,7 @@ class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
+                    [3, 224, 224],
                     torch.ones((3, 224, 224)),
                     torch.zeros((3, 224, 224)),
                     1e-05,
@@ -99,9 +62,9 @@ class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3]),
-                    torch.ones((3)),
-                    torch.zeros((3)),
+                    [3],
+                    torch.randn((3)),
+                    torch.randn((3)),
                     1e-05,
                 )[0]

@@ -120,15 +83,15 @@ def forward(self, x):
         )

     @parameterized.expand([((5, 3, 2, 4), [2, 4])])
-    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
                     normalized_shape,
                     None,
                     None,
-                    eps,
+                    1e-05,
                 )[0]

         inputs = [torch.randn(input_shape)]
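For reference (eager-mode sketch, nothing TensorRT-specific): one of the parameterized cases above, run in plain PyTorch. native_layer_norm returns an (output, mean, rstd) triple, which is why the tests index [0] and why the converter above returns placeholder mean/rstd values.

import torch

input_shape, normalized_shape = (5, 3, 2, 4), [2, 4]
out, mean, rstd = torch.ops.aten.native_layer_norm.default(
    torch.randn(input_shape),
    normalized_shape,
    torch.randn(normalized_shape),
    torch.randn(normalized_shape),
    1e-05,
)
print(out.shape, mean.shape, rstd.shape)  # [5, 3, 2, 4], [5, 3, 1, 1], [5, 3, 1, 1]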
