Commit 769f0a5

cascade812 authored and Yuqi Zhang committed

[Feature] Add async tensor parallelism using compilation pass (vllm-project#17882)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>

1 parent 65e2f72 · commit 769f0a5

11 files changed: +472 −56 lines

.buildkite/test-pipeline.yaml (1 addition, 0 deletions)

@@ -316,6 +316,7 @@ steps:
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
+  - pytest -v -s compile/test_async_tp.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental, amdproduction]

tests/compile/backend.py (18 additions, 0 deletions)

@@ -5,6 +5,8 @@
 
 from torch import fx
 
+from vllm.compilation.fx_utils import (find_specified_fn,
+                                       find_specified_fn_maybe)
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import get_current_vllm_config
 
@@ -44,3 +46,19 @@ def post_pass(self, graph: fx.Graph):
         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
+
+    def check_before_ops(self, ops,
+                         find_fn=find_specified_fn,
+                         find_fn_maybe=find_specified_fn_maybe,
+                         ops_fully_replaced=True):
+        for op in ops:
+            find_fn(self.graph_pre_pass.nodes, op)
+            if ops_fully_replaced:
+                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
+
+    def check_after_ops(self, ops,
+                        find_fn=find_specified_fn,
+                        find_fn_maybe=find_specified_fn_maybe):
+        for op in ops:
+            find_fn(self.graph_post_pass.nodes, op)
+            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
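
A minimal usage sketch of the new helpers (not part of the commit; `my_pass`, `MyModel`, and `x` are hypothetical placeholders, and the ops shown are the ones checked by the async-TP test below): compile a model through TestBackend with the pass under test, run it once so both graphs are captured, then assert which ops appear before and after the pass.

import torch

from .backend import TestBackend  # relative import, as used by the tests in this commit

backend = TestBackend(my_pass)  # my_pass: the inductor pass under test (placeholder)
compiled_model = torch.compile(MyModel(), backend=backend)
compiled_model(x)  # one call captures graph_pre_pass / graph_post_pass

# The unfused collective must be findable in the pre-pass graph;
# ops_fully_replaced=False skips asserting that it is fully gone afterwards.
backend.check_before_ops([torch.ops.vllm.all_gather.default],
                         ops_fully_replaced=False)
# The fused op must appear in the post-pass graph and not in the pre-pass graph.
backend.check_after_ops([torch.ops.symm_mem.fused_all_gather_matmul.default])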

tests/compile/test_async_tp.py (248 additions, 0 deletions)

@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+
+import pytest
+import torch
+
+import vllm.envs as envs
+from vllm.compilation.collective_fusion import AsyncTPPass
+from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
+                         PassConfig, VllmConfig)
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.platforms import current_platform
+from vllm.utils import update_environment_variables
+
+from ..models.registry import HF_EXAMPLE_MODELS
+from ..utils import (compare_two_settings, create_new_process_for_each_test,
+                     multi_gpu_test)
+from .backend import TestBackend
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+
+class TestMMRSModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.gate_proj = torch.nn.Parameter(torch.empty(
+            (self.hidden_size * 2, hidden_size)),
+                                            requires_grad=False)
+        # Initialize weights
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+    def forward(self, hidden_states):
+        """
+        Forward pass implementing the mm + reduce scatter in the FX graph
+
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]
+
+
+class TestAGMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.weight = torch.nn.Parameter(torch.empty(
+            (hidden_size, hidden_size)),
+                                         requires_grad=False)
+        # Initialize weights
+        torch.nn.init.normal_(self.weight, std=0.02)
+
+    def forward(self, hidden_states):
+        """
+        Forward pass implementing the mm + all gather in the FX graph
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+        all_gather = tensor_model_parallel_all_gather(view, dim=0)
+        permute = self.weight.permute(1, 0)
+        mm = torch.mm(all_gather, permute)
+        return mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_matmul.default]
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seq_len", [16])
+@pytest.mark.parametrize("hidden_size", [16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
+                               hidden_size: int, dtype: torch.dtype):
+    num_processes = 2
+
+    def run_torch_spawn(fn, nprocs):
+        # need to use torch.mp.spawn otherwise will have problems with
+        # torch.distributed and cuda
+        torch.multiprocessing.spawn(fn,
+                                    args=(num_processes, test_model,
+                                          batch_size, seq_len, hidden_size,
+                                          dtype),
+                                    nprocs=nprocs)
+
+    run_torch_spawn(async_tp_pass_on_test_model, num_processes)
+
+
+def async_tp_pass_on_test_model(local_rank: int, world_size: int,
+                                test_model_cls: torch.nn.Module,
+                                batch_size: int, seq_len: int,
+                                hidden_size: int, dtype: torch.dtype):
+    current_platform.seed_everything(0)
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': '12345',
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # configure vllm config for AsyncTPPass
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
+        enable_async_tp=True, ), )
+    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
+
+    # this is a fake model name to construct the model config
+    # in the vllm_config, it's not really used.
+    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
+    vllm_config.model_config = ModelConfig(model=model_name,
+                                           task="auto",
+                                           tokenizer=model_name,
+                                           tokenizer_mode="auto",
+                                           trust_remote_code=True,
+                                           dtype=dtype,
+                                           seed=42)
+
+    async_tp_pass = AsyncTPPass(vllm_config)
+    backend = TestBackend(async_tp_pass)
+
+    model = test_model_cls(hidden_size)
+
+    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                dtype=dtype,
+                                requires_grad=False)
+
+    compiled_model = torch.compile(model, backend=backend)
+    compiled_model(hidden_states)
+
+    # In pre-nodes, all gather or reduce scatter should exist;
+    # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
+    backend.check_before_ops(model.ops_in_model_before(),
+                             ops_fully_replaced=False)
+
+    # In post-nodes, fused_matmul_reduce_scatter or
+    # fused_all_gather_matmul should exist
+    backend.check_after_ops(model.ops_in_model_after())
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("async_tp_enabled", [True])
+@pytest.mark.parametrize("distributed_backend", ["mp"])
+@pytest.mark.parametrize("eager_mode", [False, True])
+def test_async_tp_pass_correctness(
+    model_id: str,
+    tp_size: int,
+    async_tp_enabled: bool,
+    distributed_backend: str,
+    eager_mode: bool,
+    num_gpus_available: int,
+):
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
+    model_info.check_transformers_version(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")
+
+    pp_size = 1
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
+
+    common_args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "8",
+    ]
+    if eager_mode:
+        common_args.append("--enforce-eager")
+
+    compilation_config = {
+        'level': 3,
+        'compile_sizes': [2, 4, 8],
+        'splitting_ops': [],
+        'pass_config': {
+            'enable_async_tp': async_tp_enabled
+        },
+    }
+
+    async_tp_env = tp_env = {
+        "VLLM_USE_V1": "1",
+    }
+
+    async_tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        distributed_backend,
+        "--compilation_config",
+        json.dumps(compilation_config),
+    ]
+
+    tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        "mp",
+    ]
+
+    compare_two_settings(model_id,
+                         async_tp_args,
+                         tp_args,
+                         async_tp_env,
+                         tp_env,
+                         method="generate")
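
For reference, the correctness test above turns the pass on end to end through the compilation config passed on the CLI. A hedged sketch of the equivalent offline-inference setup (assuming the LLM entrypoint accepts a compilation_config dict, as in recent vLLM releases; requires 2 GPUs):

# Sketch, not part of the commit: enable the async-TP pass for offline inference,
# mirroring the JSON compilation config used by test_async_tp_pass_correctness.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    tensor_parallel_size=2,
    compilation_config={
        "level": 3,
        "pass_config": {"enable_async_tp": True},
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)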

tests/compile/test_fusion.py (17 additions, 19 deletions)

@@ -29,6 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
         self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        self.key = QuantKey(dtype=FP8_DTYPE,
+                            static=static,
+                            per_tensor=static,
+                            symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         else:
@@ -59,6 +63,15 @@ def forward(self, x):
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
+    def ops_in_model_before(self):
+        return [QUANT_OPS[self.key]]
+
+    def ops_in_model_after(self):
+        return [
+            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
+            FUSED_OPS[FusedRMSQuantKey(self.key, True)]
+        ]
+
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
 
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
     # In pre-nodes, fp8 quant should be there and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+    backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
+                             find_auto_fn_maybe)
 
     # In post-nodes, fused kernels should be there and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
+                            find_auto_fn_maybe)
