diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index b1e4fbf24c..6db7ed667e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -9,6 +9,7 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, + get_axes_for_reduce_op, get_positive_dim, get_trt_tensor, to_numpy, @@ -105,102 +106,30 @@ def layer_norm( cudnn_enable: bool, return_mean_rstd: bool, ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]: - if weight is None: - weight = to_numpy(1.0) - - if bias is None: - bias = to_numpy(0.0) - - shape = weight.shape - gamma = to_numpy(weight).reshape(shape) - beta = to_numpy(bias).reshape(shape) - - dims = list(range(len(input.shape) - len(shape), len(input.shape))) - - # E[x] - mean_expected_trt = impl.reduce.mean( - ctx, target, source_ir, f"{name}_mean_expected", input, dims, True - ) - - # X-E[x] - sub_trt = impl.elementwise.sub( - ctx, - target, - source_ir, - f"{name}_sub", - input, - mean_expected_trt, - ) - - # Variance = mean(pow(x_sub_mean, 2)) - pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32) - pow_var = impl.elementwise.pow( - ctx, - target, - source_ir, - f"{name}_pow_var", - sub_trt, - pow_trt, - ) - mean_trt = impl.reduce.mean( - ctx, target, source_ir, f"{name}_mean", pow_var, dims, True - ) - - # sqrt((var + eps)) - eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32) - add_trt = impl.elementwise.add( - ctx, - target, - source_ir, - f"{name}_add", - mean_trt, - eps_trt, - ) - sqrt_trt = impl.unary.sqrt( - ctx, - target, - source_ir, - f"{name}_sqrt", - add_trt, - ) - - # (X - E[X]) / sqrt((var + eps)) - div_trt = impl.elementwise.div( - ctx, - target, - source_ir, - f"{name}_div", - sub_trt, - sqrt_trt, - ) - - gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma") - beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta") - - # y * gamma + beta - scaled_y = impl.elementwise.mul( - ctx, - target, - source_ir, - f"{name}_mul_gamma", - div_trt, - gamma_trt, - ) + dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape))) + axes = get_axes_for_reduce_op(dims) + + weight = get_trt_tensor(ctx, weight, f"{name}_weight") + bias = get_trt_tensor(ctx, bias, f"{name}_bias") + if tuple(input.shape) != tuple(weight.shape): + weight = impl.slice.expand( + ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape + ) + if tuple(input.shape) != tuple(bias.shape): + bias = impl.slice.expand( + ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape + ) - output = impl.elementwise.add( - ctx, - target, - source_ir, - f"{name}_add_beta", - scaled_y, - beta_trt, - ) + layer_norm = ctx.net.add_normalization(input, weight, bias, axes) + layer_norm.epsilon = eps + layer_norm.compute_precision = input.dtype + set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir) if return_mean_rstd: # return fake mean and rstd for now - return output, None, None + return layer_norm.get_output(0), None, None - return output + return layer_norm.get_output(0) def native_group_norm( diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py index 7f43234211..c0e055304a 100644 --- a/tests/py/dynamo/conversion/test_layer_norm_aten.py +++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py @@ -1,4 +1,5 @@ import torch +from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -6,19 +7,31 @@ class TestLayerNormConverter(DispatchTestCase): - def test_layer_norm(self): + @parameterized.expand( + [ + ( + (5, 3, 2, 4), + [ + 4, + ], + ), + ((5, 3, 2, 4), [2, 4]), + ((5, 3, 2, 4), [3, 2, 4]), + ((5, 3, 2, 4), [5, 3, 2, 4]), + ] + ) + def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05): class LayerNorm(torch.nn.Module): def forward(self, x): return torch.ops.aten.layer_norm.default( x, - torch.tensor([3, 224, 224]), - torch.ones((3, 224, 224)), - torch.zeros((3, 224, 224)), - 1e-05, - True, + normalized_shape, + torch.randn(normalized_shape), + torch.randn(normalized_shape), + eps, ) - inputs = [torch.randn(1, 3, 224, 224)] + inputs = [torch.randn(input_shape)] self.run_test( LayerNorm(), inputs, @@ -26,7 +39,37 @@ def forward(self, x): class TestNativeLayerNormConverter(DispatchTestCase): - def test_layer_norm(self): + @parameterized.expand( + [ + ( + (5, 3, 2, 4), + [ + 4, + ], + ), + ((5, 3, 2, 4), [2, 4]), + ((5, 3, 2, 4), [3, 2, 4]), + ((5, 3, 2, 4), [5, 3, 2, 4]), + ] + ) + def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05): + class LayerNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_layer_norm.default( + x, + normalized_shape, + torch.randn(normalized_shape), + torch.randn(normalized_shape), + eps, + )[0] + + inputs = [torch.randn(input_shape)] + self.run_test( + LayerNorm(), + inputs, + ) + + def test_layernorm_with_dynamic_shape(self): class LayerNorm(torch.nn.Module): def forward(self, x): return torch.ops.aten.native_layer_norm.default( @@ -37,10 +80,17 @@ def forward(self, x): 1e-05, )[0] - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test( + input_specs = [ + Input( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], + ), + ] + + self.run_test_with_dynamic_shape( LayerNorm(), - inputs, + input_specs, ) diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 457e9e2e81..533e9d84d3 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -475,14 +475,13 @@ def forward(self, x): optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - f"Select_scatter TRT outputs don't match with the original model.", + optimized_model_results_shape = optimized_model_results.size() + torch_model_results_shape = torch_model_results.size() + + self.assertEquals( + optimized_model_results_shape, + torch_model_results_shape, + f"The optimized model results shape and torch model results shape should be equal in empty_like", )