Skip to content

Commit 5e38c04

Browse files
committed
test: support automatic plugin feature with different dimensions and add flashinfer.rmsnorm support test case
1 parent 29b65b0 commit 5e38c04

File tree

5 files changed

+55
-10
lines changed

5 files changed

+55
-10
lines changed

Diff for: .github/workflows/build-test-linux.yml

+1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ jobs:
142142
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
143143
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
144144
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
145+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
145146
popd
146147
147148
tests-py-dynamo-fe:

Diff for: py/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ pybind11==2.6.2
55
torch>=2.8.0.dev,<2.9.0
66
torchvision>=0.22.0.dev,<0.23.0
77
--extra-index-url https://pypi.ngc.nvidia.com
8-
pyyaml
8+
pyyaml

Diff for: tests/py/dynamo/automatic_plugin/test_automatic_plugin.py

-9
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
8181

8282
if __name__ == "__main__":
8383
run_tests()
84-
85-
# Example Usage
86-
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
87-
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
88-
89-
# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
90-
91-
# print("C (Addition):", C)
92-
# print("D (Multiplication):", D)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
3+
flashinfer = pytest.importorskip("flashinfer")
4+
import torch
5+
import torch.nn as nn
6+
import torch_tensorrt
7+
from parameterized import parameterized
8+
from torch.testing._internal.common_utils import run_tests
9+
from torch_tensorrt._enums import dtype
10+
11+
from ..conversion.harness import DispatchTestCase
12+
13+
14+
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
15+
def flashinfer_rmsnorm(
16+
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
17+
) -> torch.Tensor:
18+
return flashinfer.norm.rmsnorm(input, weight)
19+
20+
21+
@torch.library.register_fake("flashinfer::rmsnorm")
22+
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
23+
return input
24+
25+
26+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
27+
"flashinfer::rmsnorm", supports_dynamic_shapes=True
28+
)
29+
30+
31+
class TestAutomaticPlugin(DispatchTestCase):
32+
@parameterized.expand(
33+
[
34+
((64, 64), (64,), torch.float16),
35+
((256, 256), (256,), torch.float16),
36+
]
37+
)
38+
def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
39+
class rmsnorm(nn.Module):
40+
def forward(self, input, weight):
41+
return torch.ops.flashinfer.rmsnorm.default(input, weight)
42+
43+
inputs = [
44+
torch.randn(input_shape, device="cuda", dtype=data_type),
45+
torch.randn(weight_shape, device="cuda", dtype=data_type),
46+
]
47+
48+
self.run_test(rmsnorm(), inputs, precision=dtype.f16)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

Diff for: tests/py/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pytest>=8.2.1
88
pytest-xdist>=3.6.1
99
pyyaml
1010
timm>=1.0.3
11+
flashinfer-python; python_version < "3.13"
1112
transformers==4.49.0
1213
nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13"
1314
--extra-index-url https://pypi.nvidia.com

0 commit comments

Comments
 (0)