
Commit bd56c98

[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: luka <[email protected]>
1 parent 084bbac commit bd56c98

File tree: 9 files changed, +249 -170 lines changed

tests/compile/backend.py

Lines changed: 9 additions & 4 deletions

@@ -13,21 +13,26 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.
+
+    Inductor config can be modified directly by editing the inductor_config
+    property. This can be helpful for adding passes like the
+    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """
 
     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.current_config = config.shallow_copy_dict()
-        self.current_config['force_disable_caches'] = True
-        self.current_config['post_grad_custom_post_pass'] = self.post_pass
+        self.inductor_config = config.shallow_copy_dict()
+        self.inductor_config['force_disable_caches'] = True
+        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
 
     def __call__(self, graph: fx.GraphModule, example_inputs):
+        self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.current_config)
+                          config_patches=self.inductor_config)
 
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
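
For orientation, a minimal usage sketch of TestBackend (not part of this commit; it assumes a CUDA device, the vLLM source tree on the path, and a vLLM version containing this change; the pre-grad pass line is purely illustrative):

    # Hedged sketch: compile a tiny model through TestBackend with the
    # renamed no-op elimination pass as the only custom post-grad pass.
    import torch

    from tests.compile.backend import TestBackend
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig

    pass_config = CompilationConfig.PassConfig(enable_noop=True)
    backend = TestBackend(NoOpEliminationPass(pass_config))

    # Inductor config can be tweaked before compilation; for example, a
    # pre-grad pass could be registered here (illustrative, not required):
    # backend.inductor_config["pre_grad_custom_pass"] = my_pre_grad_pass

    model = torch.nn.Linear(16, 16, device="cuda", dtype=torch.float16)
    compiled = torch.compile(model, backend=backend)
    out = compiled(torch.rand(8, 16, device="cuda", dtype=torch.float16))

    # Graphs captured by the backend for inspection:
    # backend.graph_pre_compile, backend.graph_pre_pass, backend.graph_post_pass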

tests/compile/test_functionalization.py

Lines changed: 4 additions & 4 deletions

@@ -9,7 +9,7 @@
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
-from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig
 
 from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     torch.set_default_device("cuda")
 
     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
+                                          enable_noop=True)
+    noop_pass = NoOpEliminationPass(config)
     fusion_pass = FusionPass.instance(config)
 
-    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
+    passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
     func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)

tests/compile/test_fusion.py

Lines changed: 69 additions & 58 deletions

@@ -5,23 +5,25 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
 
 from .backend import TestBackend
 
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
@@ -41,15 +43,17 @@ def forward(self, x):
                               self.w[0],
                               self.wscale[0],
                               self.scale[0],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
         x3 = apply_fp8_linear(y2,
                               self.w[1],
                               self.wscale[1],
                               self.scale[1],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
@@ -59,60 +63,67 @@ def forward(self, x):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
 
-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
-
-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
-
-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
-
-    result = model(x)
-
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
-
-    # Higher tol for dynamic, even higher for bfloat16
-    if static:
-        ATOL, RTOL = (1e-3, 1e-3)
-    elif dtype == torch.float16:
-        ATOL, RTOL = (2e-3, 2e-3)
-    else:
-        ATOL, RTOL = (1e-2, 1e-2)
-
-    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
-
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
-
-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)
+
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        # Higher tol for dynamic, even higher for bfloat16
+        if static:
+            ATOL, RTOL = (1e-3, 1e-3)
+        elif dtype == torch.float16:
+            ATOL, RTOL = (2e-3, 2e-3)
+        else:
+            ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes
+
+        # static is per-tensor, dynamic is per-token
+        key = QuantKey(dtype=FP8_DTYPE,
+                       static=static,
+                       per_tensor=static,
+                       symmetric=True)
+        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+        fp8_quant = QUANT_OPS[key]
+
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)
+
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
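
The assertions above use find_auto_fn / find_auto_fn_maybe from vllm.compilation.fx_utils to locate auto_functionalized nodes wrapping a given op. As a rough, hedged sketch of that kind of lookup (the real helpers may be implemented differently; the function names below are placeholders):

    # Hypothetical sketch of an auto_functionalized-node lookup; see
    # vllm.compilation.fx_utils for the real helpers.
    from typing import Iterable, Optional

    import torch
    import torch.fx


    def find_auto_fn_maybe_sketch(nodes: Iterable[torch.fx.Node],
                                  op) -> Optional[torch.fx.Node]:
        for node in nodes:
            # auto_functionalized is a higher-order op that carries the
            # wrapped (mutating) op as its first argument.
            if (node.op == "call_function"
                    and node.target is torch.ops.higher_order.auto_functionalized
                    and node.args[0] is op):
                return node
        return None


    def find_auto_fn_sketch(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
        node = find_auto_fn_maybe_sketch(nodes, op)
        assert node is not None, f"auto_functionalized({op}) not found"
        return node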

vllm/compilation/noop_elimination.py

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Iterable, Union
+
+import torch.fx
+from torch import SymInt
+
+from vllm.logger import init_logger
+
+from .fx_utils import is_func
+from .vllm_inductor_pass import VllmInductorPass
+
+logger = init_logger(__name__)
+
+
+class NoOpEliminationPass(VllmInductorPass):
+    """
+    This is an inductor pass that removes redundant reshape/slice operations.
+    It is required for RMSNorm-quant fusion to work properly.
+    That's because apply_fp8_linear adds a reshape, which is redundant
+    in the 2D-case. Additionally, torch internal no-op elimination pass does
+    not handle certain slice variants.
+
+    Example graph 1:
+    getitem_1: "f16[s0, 4096]" = ...
+    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
+    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Can be replaced with:
+    getitem_1: "f16[s0, 4096]" = ...
+    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Example graph 2:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
+    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
+    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
+
+    Can be replaced with:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
+    out: "f16[s0, 4096]" = at[1]
+
+    TODO(luka): This is currently tested in test_fusion,
+    but separate tests could be good.
+    """
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_noop_elimination")
+        count = 0
+        # Remove no-op reshapes/views:
+        for node in graph.nodes:
+            if is_func(node, torch.ops.aten.reshape.default):
+                input, shape = node.args[:2]
+                input_shape = input.meta["val"].shape
+                if len(shape) != len(input_shape):
+                    # Reshape changing rank, skip
+                    continue
+
+                if shape.count(-1) > 1:
+                    # Invalid reshape args, skip
+                    continue
+
+                if self.all_dims_equivalent(shape, input_shape):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice.Tensor):
+                input, dim_index, start, end = node.args[:4]
+                input_shape = input.meta["val"].shape
+                i_dim = input_shape[dim_index]
+
+                if start == 0 and self.dims_equivalent(end, i_dim):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice_scatter.default):
+                base, view, dim_index, start, end = node.args[:5]
+                base_shape = base.meta["val"].shape
+                view_shape = view.meta["val"].shape
+
+                view_dim = view_shape[dim_index]
+
+                # Check that view fully covers base and the full view is used
+                # (if the view fully covered the base after slicing but was not
+                # fully used, we could replace slice_scatter with a simple slice
+                # but that's a niche case).
+                if (base_shape == view_shape and start == 0
+                        and self.dims_equivalent(end, view_dim)):
+                    node.replace_all_uses_with(view)
+                    graph.erase_node(node)
+                    count += 1
+
+        logger.debug("Removed %s no-op reshapes and slices", count)
+        self.dump_graph(graph, "after_noop_elimination")
+        self.end_and_log()
+
+    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
+                            i_dims: Iterable[Union[int, SymInt]]):
+        return all(
+            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
+
+    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
+                        i_dim: Union[int, SymInt]) -> bool:
+        """
+        This function checks if two dimensions are equivalent.
+        :param dim: The dimension arg to reshape/slice
+        :param i_dim: The corresponding dimension in the input tensor
+        :return: Are the dimensions equivalent?
+
+        There are three cases in which the dimensions are equivalent:
+        1. The dimensions are equal (both integers)
+        2. The reshape dimension is -1 (i.e. inferred)
+        3. The dimensions both correspond to the same SymInt
+
+        While case 2 does not guarantee the dimensions are equal,
+        they are equal if all other dimensions are equal.
+
+        In case 3, the reshape dimension is a torch.fx.Node,
+        and its value is a SymInt. That value is equal to the
+        input dimension.
+        """
+        # Case 1 and 2
+        if dim == i_dim or dim == -1:
+            return True
+        # Case 3
+        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
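
To make the reshape case concrete, here is a standalone, simplified sketch of the same idea on a plain FX graph (an illustration under assumptions, not the pass itself: it matches the Python-level torch.reshape and static shapes only, whereas the real pass works on the post-grad ATen graph, reads node.meta["val"], and also handles SymInts, slice and slice_scatter):

    # Simplified, self-contained illustration of no-op reshape elimination.
    import torch
    import torch.fx
    from torch.fx.passes.shape_prop import ShapeProp


    class M(torch.nn.Module):
        def forward(self, x):  # x: [n, 4096]
            y = torch.reshape(x, [-1, 4096])  # a no-op in the 2D case
            return y + 1


    def eliminate_noop_reshapes(gm: torch.fx.GraphModule) -> int:
        count = 0
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target is torch.reshape:
                inp, shape = node.args[:2]
                input_shape = inp.meta["tensor_meta"].shape
                # Same checks as the pass: rank preserved and every dim
                # either equal or inferred (-1).
                if len(shape) == len(input_shape) and all(
                        d == -1 or d == i for d, i in zip(shape, input_shape)):
                    node.replace_all_uses_with(inp)
                    gm.graph.erase_node(node)
                    count += 1
        gm.graph.lint()
        gm.recompile()
        return count


    gm = torch.fx.symbolic_trace(M())
    ShapeProp(gm).propagate(torch.rand(8, 4096))  # fills node.meta["tensor_meta"]
    print(eliminate_noop_reshapes(gm))  # expected: 1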

vllm/compilation/pass_manager.py

Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,7 @@
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
-from .reshapes import RedundantReshapesPass
+from .noop_elimination import NoOpEliminationPass
 
 logger = init_logger(__name__)
 
@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
 
     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
-    2. default passes (RedundantReshapesPass, FusionPass)
+    2. default passes (NoopEliminationPass, FusionPass)
     3. config["post_grad_custom_post_pass"] (if it exists)
     4. fix_functionalization
     This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ def __call__(self, graph: fx.Graph):
 
     def configure(self, pass_config: CompilationConfig.PassConfig):
         self.pass_config = pass_config
-        if pass_config.enable_reshape:
-            self.passes += [RedundantReshapesPass(pass_config)]
+        if pass_config.enable_noop:
+            self.passes += [NoOpEliminationPass(pass_config)]
 
         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]
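
With the rename, the no-op pass is switched on through PassConfig.enable_noop rather than enable_reshape. A hedged sketch of what configure() ends up assembling (assuming the vLLM version containing this commit; defaults and required surrounding context may differ):

    # Sketch of the default pass list built by PostGradPassManager.configure().
    from vllm.compilation.fusion import FusionPass
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig

    pass_config = CompilationConfig.PassConfig(enable_noop=True,
                                               enable_fusion=True)

    passes = []
    if pass_config.enable_noop:
        # No-op elimination runs first so redundant reshapes/slices do not
        # hide the RMSNorm + quant patterns from the fusion pass.
        passes.append(NoOpEliminationPass(pass_config))
    if pass_config.enable_fusion:
        # FusionPass may additionally expect a current vLLM config to be set
        # (as test_fusion does via set_current_vllm_config).
        passes.append(FusionPass.instance(pass_config))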
