
Commit 2595deb

tlrmchlsmth and shreyankg authored and committed
[V1] EP/TP MoE + DP Attention (vllm-project#13931)
1 parent 6691a77 commit 2595deb

File tree

17 files changed: +250 −75 lines


examples/offline_inference/data_parallel.py

Lines changed: 10 additions & 7 deletions
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
-# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# usage:
+# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
+#     python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
 
 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port
 
+GPUs_per_dp_rank = 2
+DP_size = 2
+
 
 def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(dp_rank)
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
                                      max_tokens=16 * (dp_rank + 1))
 
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
-              tensor_parallel_size=2,
+    llm = LLM(model="ibm-research/PowerMoE-3b",
+              tensor_parallel_size=GPUs_per_dp_rank,
               enforce_eager=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
 
 if __name__ == "__main__":
     from multiprocessing import Process
-    dp_size = 2
-    GPUs_per_dp_rank = 2
     dp_master_ip = "127.0.0.1"
     dp_master_port = get_open_port()
     procs = []
-    for i in range(dp_size):
+    for i in range(DP_size):
         proc = Process(target=main,
-                       args=(dp_size, i, dp_master_ip, dp_master_port,
+                       args=(DP_size, i, dp_master_ip, dp_master_port,
                              GPUs_per_dp_rank))
         proc.start()
         procs.append(proc)
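Reviewer note: each spawned process pins its own DP rank before constructing the LLM. A minimal sketch of that per-rank setup, assuming the usual `VLLM_DP_*` variable names (only `VLLM_DP_RANK` is visible in the hunk above; the rest are assumptions for illustration):

```python
import os

# Sketch of the per-rank environment setup done at the top of main().
def configure_dp_env(dp_size, dp_rank, dp_master_ip, dp_master_port):
    os.environ["VLLM_DP_RANK"] = str(dp_rank)                 # shown in the diff
    os.environ["VLLM_DP_SIZE"] = str(dp_size)                 # assumed name
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip            # assumed name
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)   # assumed name
```

With the module-level defaults above, the example uses DP_size * GPUs_per_dp_rank = 4 GPUs in total.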

tests/kernels/test_moe.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
+        dp_size=1,
     ).cuda()
 
     # Load the weights
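Reviewer note: the MoE constructor now takes `dp_size` alongside `tp_size` because, with expert parallelism enabled, experts can be sharded across the combined TP x DP world. A hedged sketch of that partitioning arithmetic (hypothetical helper, not the vLLM implementation):

```python
def experts_per_rank(num_experts: int, tp_size: int, dp_size: int) -> int:
    # Hypothetical sketch: under expert parallelism, the expert-parallel
    # world size is assumed to be the combined TP x DP world.
    ep_size = tp_size * dp_size
    assert num_experts % ep_size == 0, "experts must divide the EP world evenly"
    return num_experts // ep_size

# e.g. 64 experts with tp_size=2, dp_size=2 -> 16 experts per rank
print(experts_per_rank(64, tp_size=2, dp_size=2))
```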

vllm/attention/layer.py

Lines changed: 2 additions & 2 deletions
@@ -324,7 +324,7 @@ def unified_attention(
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
@@ -356,7 +356,7 @@ def unified_attention_with_output(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
                       query,
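Reviewer note: the rename from `attn_layers` to `no_compile_layers` reflects that this dict can hold any layer kept outside torch.compile, not just attention. A small sketch, assuming layers register themselves in the static forward context under their name at construction time (the registration shape is an assumption; only the lookup side appears in this diff):

```python
import torch.nn as nn

# Sketch (assumption): each no-compile layer registers itself in the
# static forward context keyed by its name, which is how the custom
# ops above can resolve `layer_name` back to the live module.
class MyAttention(nn.Module):
    def __init__(self, layer_name: str, static_forward_context: dict):
        super().__init__()
        static_forward_context[layer_name] = self
```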

vllm/compilation/backends.py

Lines changed: 3 additions & 2 deletions
@@ -396,8 +396,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
-        local_cache_dir = os.path.join(
-            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        rank = vllm_config.parallel_config.rank
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
         self.compilation_config.local_cache_dir = local_cache_dir
 
         disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
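Reviewer note: folding the DP rank into the per-rank compile-cache path keeps two DP ranks with the same TP rank from sharing, and potentially corrupting, one cache directory. A quick illustration of the resulting layout, with a placeholder cache path:

```python
import os

# Illustration: with 2 TP ranks and 2 DP ranks, every (rank, dp_rank)
# pair now gets its own compile-cache directory.
cache_dir = "/tmp/vllm_compile_cache"  # placeholder path
for rank in range(2):          # tensor-parallel rank
    for dp_rank in range(2):   # data-parallel rank
        print(os.path.join(cache_dir, f"rank_{rank}_{dp_rank}"))
# -> rank_0_0, rank_0_1, rank_1_0, rank_1_1
```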

vllm/forward_context.py

Lines changed: 15 additions & 7 deletions
@@ -25,16 +25,22 @@
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+@dataclass
+class DPMetadata:
+    num_tokens_across_dp: list[int]
+    cu_tokens_across_dp_cpu: torch.Tensor
+
+
 @dataclass
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
-    attn_layers: dict[str, Any]
+    no_compile_layers: dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    num_tokens_across_dp: Optional[
-        list[int]] = None  # set dynamically for each forward pass
+    # set dynamically for each forward pass
+    dp_metadata: Optional[DPMetadata] = None
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
-    num_tokens_across_dp = None
+    dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
@@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any,
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        num_tokens_across_dp = num_tokens_tensor.tolist()
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)
 
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        attn_layers=vllm_config.compilation_config.static_forward_context,
+        no_compile_layers=vllm_config.compilation_config.
+        static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        num_tokens_across_dp=num_tokens_across_dp)
+        dp_metadata=dp_metadata)
     try:
         yield
     finally:
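Reviewer note: DPMetadata carries both the per-rank token counts and their cumulative sum, which lets a consumer recover one DP rank's slice of a tensor concatenated across the DP group. A sketch of that use, under the assumption that downstream MoE code indexes this way (not shown in this diff):

```python
import torch

# Sketch (assumption): slicing one DP rank's tokens out of a tensor
# concatenated across the DP group using the cumulative counts.
# num_tokens_across_dp = [3, 5] -> cu_tokens_across_dp_cpu = [3, 8]
num_tokens = torch.tensor([3, 5], dtype=torch.int32)
cu = torch.cumsum(num_tokens, dim=0)

def dp_slice(hidden_all: torch.Tensor, dp_rank: int) -> torch.Tensor:
    start = 0 if dp_rank == 0 else cu[dp_rank - 1].item()
    end = cu[dp_rank].item()
    return hidden_all[start:end]

hidden_all = torch.randn(8, 16)       # 3 + 5 tokens, hidden size 16
print(dp_slice(hidden_all, 1).shape)  # torch.Size([5, 16])
```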
