From 00ddd557d7e4ecaccab0a077b151f1c061637c3f Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Fri, 18 Apr 2025 21:04:56 +0000
Subject: [PATCH] test: support automatic plugin feature with different dimensions and add flashinfer.rmsnorm support test case

---
 .github/workflows/build-test-linux.yml        |  1 +
 py/requirements.txt                           |  2 +-
 .../automatic_plugin/test_automatic_plugin.py |  9 ----
 .../test_flashinfer_rmsnorm.py                | 59 +++++++++++++++++++
 tests/py/requirements.txt                     |  1 +
 5 files changed, 62 insertions(+), 10 deletions(-)
 create mode 100644 tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py

diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml
index 4d252b24e4..d0fabf9993 100644
--- a/.github/workflows/build-test-linux.yml
+++ b/.github/workflows/build-test-linux.yml
@@ -142,6 +142,7 @@ jobs:
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
+          python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
           popd
 
   tests-py-dynamo-fe:
diff --git a/py/requirements.txt b/py/requirements.txt
index 696700398d..dbb342dae5 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -5,4 +5,4 @@ pybind11==2.6.2
 torch>=2.8.0.dev,<2.9.0
 torchvision>=0.22.0.dev,<0.23.0
 --extra-index-url https://pypi.ngc.nvidia.com
-pyyaml
+pyyaml
\ No newline at end of file
diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
index ae60f8cda7..8ab47def08 100644
--- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
+++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
 
 if __name__ == "__main__":
     run_tests()
-
-# Example Usage
-# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
-# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
-
-# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
-
-# print("C (Addition):", C)
-# print("D (Multiplication):", D)
diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
new file mode 100644
index 0000000000..6c10aafb7a
--- /dev/null
+++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -0,0 +1,59 @@
+import pytest
+
+flashinfer = pytest.importorskip("flashinfer")
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt._enums import dtype
+
+from ..conversion.harness import DispatchTestCase
+
+
+@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
+def flashinfer_rmsnorm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """Run flashinfer's RMSNorm kernel as the ``flashinfer::rmsnorm`` op."""
+    # Forward eps so a caller-supplied value is not silently dropped.
+    return flashinfer.norm.rmsnorm(input, weight, eps=eps)
+
+
+@torch.library.register_fake("flashinfer::rmsnorm")
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Fake implementation: output has the shape and dtype of ``input``."""
+    # Return a fresh tensor: the op is declared with mutates_args=().
+    return torch.empty_like(input)
+
+
+# Generate the TensorRT plugin and its converter for the custom op.
+torch_tensorrt.dynamo.conversion.plugins.custom_op(
+    "flashinfer::rmsnorm", supports_dynamic_shapes=True
+)
+
+
+class TestAutomaticPlugin(DispatchTestCase):
+    """Runs flashinfer::rmsnorm through the harness for several shapes."""
+
+    @parameterized.expand(
+        [
+            ((64, 64), (64,), torch.float16),
+            ((256, 256), (256,), torch.float16),
+        ]
+    )
+    def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
+        class rmsnorm(nn.Module):
+            def forward(self, input, weight):
+                return torch.ops.flashinfer.rmsnorm.default(input, weight)
+
+        inputs = [
+            torch.randn(input_shape, device="cuda", dtype=data_type),
+            torch.randn(weight_shape, device="cuda", dtype=data_type),
+        ]
+
+        self.run_test(rmsnorm(), inputs, precision=dtype.f16)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt
index 4f3c4e083b..011ed01e35 100644
--- a/tests/py/requirements.txt
+++ b/tests/py/requirements.txt
@@ -8,6 +8,7 @@ pytest>=8.2.1
 pytest-xdist>=3.6.1
 pyyaml
 timm>=1.0.3
+flashinfer-python; python_version < "3.13"
 transformers==4.49.0
 nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13"
 --extra-index-url https://pypi.nvidia.com