 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
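Note: the constructor change above threads a cutlass_fp8_enabled flag into the test model so that forward() (next hunk) can pass it to apply_fp8_linear as cutlass_fp8_supported, letting one parametrized test cover both the CUTLASS fp8 GEMM path and the torch fallback. A minimal sketch of the pattern, with a hypothetical linear_fp8 standing in for vllm's real helper:

import torch

def linear_fp8(x, w, *, cutlass_fp8_supported: bool):
    # Stand-in for apply_fp8_linear: a real implementation would dispatch
    # to a CUTLASS kernel or a torch fallback based on the flag; both
    # branches here are plain matmul so the sketch runs anywhere.
    return x @ w.t()

class ToyModel(torch.nn.Module):
    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool):
        super().__init__()
        self.cutlass_fp8_enabled = cutlass_fp8_enabled
        self.w = torch.nn.Parameter(torch.randn(hidden_size, hidden_size))

    def forward(self, x):
        # The flag captured at construction picks the GEMM backend per call.
        return linear_fp8(x, self.w,
                          cutlass_fp8_supported=self.cutlass_fp8_enabled)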
@@ -41,15 +43,17 @@ def forward(self, x):
                               self.w[0],
                               self.wscale[0],
                               self.scale[0],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

         x3 = apply_fp8_linear(y2,
                               self.w[1],
                               self.wscale[1],
                               self.scale[1],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
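The test body below picks looser tolerances for dynamic quantization and builds its QuantKey from the distinction its comment notes: static quantization uses one per-tensor scale, dynamic uses one scale per token. A rough sketch of the two scale computations, assuming the e4m3 fp8 format used on CUDA (illustrative only, not vllm's exact kernels):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3

def per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # static / per-tensor: a single scale for the whole activation tensor
    return x.abs().amax().float() / FP8_MAX

def per_token_scales(x: torch.Tensor) -> torch.Tensor:
    # dynamic / per-token: one scale per row, computed at runtime
    return x.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX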
@@ -59,60 +63,67 @@ def forward(self, x):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
-
-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
-
-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
-
-    result = model(x)
-
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
-
-    # Higher tol for dynamic, even higher for bfloat16
-    if static:
-        ATOL, RTOL = (1e-3, 1e-3)
-    elif dtype == torch.float16:
-        ATOL, RTOL = (2e-3, 2e-3)
-    else:
-        ATOL, RTOL = (1e-2, 1e-2)
-
-    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
-
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
-
-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # NoOp elimination pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)
+
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        # Higher tol for dynamic, even higher for bfloat16
+        if static:
+            ATOL, RTOL = (1e-3, 1e-3)
+        elif dtype == torch.float16:
+            ATOL, RTOL = (2e-3, 2e-3)
+        else:
+            ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes
+
+        # static is per-tensor, dynamic is per-token
+        key = QuantKey(dtype=FP8_DTYPE,
+                       static=static,
+                       per_tensor=static,
+                       symmetric=True)
+        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+        fp8_quant = QUANT_OPS[key]
+
+        # In pre-nodes, fp8 quant should be present and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)
+
+        # In post-nodes, fused kernels should be present and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
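For reference, the graph assertions above rely on vllm's fx_utils helpers, which look for auto_functionalized higher-order-op nodes wrapping a given custom op in the traced FX graph. A simplified sketch of find_auto_fn_maybe under that assumption (the real implementation lives in vllm/compilation/fx_utils.py and may differ in detail):

from typing import Optional

import torch.fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def find_auto_fn_maybe(nodes, op) -> Optional[torch.fx.Node]:
    # Custom ops with mutable arguments appear in the traced graph as
    # auto_functionalized(op, ...) call_function nodes; return the first
    # node wrapping `op`, or None if the graph contains none.
    for node in nodes:
        if (node.op == "call_function" and node.target is auto_functionalized
                and node.args[0] == op):
            return node
    return None

find_auto_fn is presumably the strict variant: it asserts such a node exists and returns it, which is why the test calls it without checking the result.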