
[V1] EP/TP MoE + DP Attention #13931

Merged (24 commits, Mar 5, 2025)
Changes from all commits
17 changes: 10 additions & 7 deletions examples/offline_inference/data_parallel.py
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
-# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# usage:
+# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
+# python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os

 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port

+GPUs_per_dp_rank = 2
+DP_size = 2
+

 def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(dp_rank)
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
                                      max_tokens=16 * (dp_rank + 1))

     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
-              tensor_parallel_size=2,
+    llm = LLM(model="ibm-research/PowerMoE-3b",
+              tensor_parallel_size=GPUs_per_dp_rank,
               enforce_eager=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):

 if __name__ == "__main__":
     from multiprocessing import Process
-    dp_size = 2
-    GPUs_per_dp_rank = 2
     dp_master_ip = "127.0.0.1"
     dp_master_port = get_open_port()
     procs = []
-    for i in range(dp_size):
+    for i in range(DP_size):
         proc = Process(target=main,
-                       args=(dp_size, i, dp_master_ip, dp_master_port,
+                       args=(DP_size, i, dp_master_ip, dp_master_port,
                              GPUs_per_dp_rank))
         proc.start()
         procs.append(proc)
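
The body of main is mostly elided above; as a rough sketch (an illustrative assumption, not necessarily what the example actually does) of how a launcher like this maps each data-parallel rank onto its own group of GPUs before building a TP-2 LLM:

import os

def pin_gpus_for_rank(dp_rank: int, gpus_per_dp_rank: int) -> None:
    # Illustrative only: with DP_size = 2 and GPUs_per_dp_rank = 2, rank 0
    # would see GPUs "0,1" and rank 1 would see "2,3"; each rank then creates
    # its own tensor-parallel LLM instance over the GPUs it can see.
    start = dp_rank * gpus_per_dp_rank
    devices = ",".join(str(i) for i in range(start, start + gpus_per_dp_rank))
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
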
1 change: 1 addition & 0 deletions tests/kernels/test_moe.py
@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
+        dp_size=1,
     ).cuda()

     # Load the weights
4 changes: 2 additions & 2 deletions vllm/attention/layer.py
@@ -324,7 +324,7 @@ def unified_attention(
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

@@ -356,7 +356,7 @@ def unified_attention_with_output(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
                       query,
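
The rename from attn_layers to no_compile_layers suggests the static forward context is no longer attention-only; with EP/TP MoE plus DP attention, other layers presumably register there as well. A minimal sketch of the lookup pattern these custom ops rely on (the layer name below is purely illustrative):

from vllm.forward_context import get_forward_context

def lookup_layer(layer_name: str):
    # no_compile_layers maps a layer's prefix to the live module instance, so a
    # custom op can reach per-layer state (kv_cache, attention impl, ...) that
    # torch.compile never traces.
    forward_context = get_forward_context()
    return forward_context.no_compile_layers[layer_name]

# e.g. lookup_layer("model.layers.0.self_attn.attn")
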
5 changes: 3 additions & 2 deletions vllm/compilation/backends.py
@@ -396,8 +396,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

         cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
-        local_cache_dir = os.path.join(
-            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        rank = vllm_config.parallel_config.rank
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
         self.compilation_config.local_cache_dir = local_cache_dir

         disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
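
For context, a small sketch of why the data-parallel rank is folded into the compile-cache path: with tensor_parallel_size=2 and data_parallel_size=2, the four workers get distinct directories (rank_0_0, rank_1_0, rank_0_1, rank_1_1); without the dp_rank suffix, the two DP replicas of a given worker rank would share, and could clobber, the same cache directory.

import os

def local_cache_dir(cache_dir: str, rank: int, dp_rank: int) -> str:
    # Mirrors the change above: one subdirectory per (worker rank, DP rank) pair.
    return os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
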
22 changes: 15 additions & 7 deletions vllm/forward_context.py
@@ -25,16 +25,22 @@
 batchsize_forward_time: defaultdict = defaultdict(list)


+@dataclass
+class DPMetadata:
+    num_tokens_across_dp: list[int]
+    cu_tokens_across_dp_cpu: torch.Tensor
+
+
 @dataclass
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
-    attn_layers: dict[str, Any]
+    no_compile_layers: dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    num_tokens_across_dp: Optional[
-        list[int]] = None  # set dynamically for each forward pass
+    # set dynamically for each forward pass
+    dp_metadata: Optional[DPMetadata] = None


 _forward_context: Optional[ForwardContext] = None
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
-    num_tokens_across_dp = None
+    dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
@@ -82,15 +88,17 @@
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        num_tokens_across_dp = num_tokens_tensor.tolist()
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)
Member commented:
Looks like num_tokens_across_dp isn't getting updated from the tensor anymore? It'll just contain the local rank value?

Collaborator (author) replied:
You are absolutely right, thank you for catching that.
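
For reference, a minimal sketch of the adjustment the reviewer is pointing at, assuming num_tokens_across_dp starts out as a list holding only this rank's own token count: after the CPU all-reduce, the per-rank counts should be read back from the reduced tensor before building DPMetadata (sketch only, not necessarily the exact fix that landed):

# Sketch: refresh the Python list from the all-reduced tensor so every DP rank
# sees every other rank's token count, instead of only its local value.
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
num_tokens_across_dp = num_tokens_tensor.tolist()
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)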


     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        attn_layers=vllm_config.compilation_config.static_forward_context,
+        no_compile_layers=vllm_config.compilation_config.
+        static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        num_tokens_across_dp=num_tokens_across_dp)
+        dp_metadata=dp_metadata)
     try:
         yield
     finally:
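
The consumer side of dp_metadata is not shown in this excerpt; as a rough sketch (hypothetical helper, not code from this PR) of how a DP-aware layer could use cu_tokens_across_dp_cpu to locate its local tokens inside the concatenated all-DP batch, e.g. when gathering hidden states across DP ranks for expert parallelism:

from vllm.forward_context import get_forward_context

def local_token_slice(dp_rank: int) -> slice:
    # Assumes data_parallel_size > 1 so dp_metadata is set.
    # cu_tokens_across_dp_cpu holds cumulative token counts per DP rank, so the
    # half-open interval [cu[dp_rank - 1], cu[dp_rank]) is this rank's share.
    cu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    start = 0 if dp_rank == 0 else int(cu[dp_rank - 1])
    end = int(cu[dp_rank])
    return slice(start, end)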