[V1] EP/TP MoE + DP Attention #13931
Changes from all commits
8548f9c
e55a971
018c7f3
84585ea
a93fde6
e6fd1b9
bb4f8ae
33e0ee0
8551ad7
2188480
4fa682b
4a2318a
21eca4c
9a93e2d
6a628cf
cbdf1bb
19e84a5
523f4bf
eb13f62
0c6fb10
32f5b02
1b864de
bd88ae1
a7668fb
vllm/forward_context.py:

```diff
@@ -25,16 +25,22 @@
 batchsize_forward_time: defaultdict = defaultdict(list)


+@dataclass
+class DPMetadata:
+    num_tokens_across_dp: list[int]
+    cu_tokens_across_dp_cpu: torch.Tensor
+
+
 @dataclass
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
-    attn_layers: dict[str, Any]
+    no_compile_layers: dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    num_tokens_across_dp: Optional[
-        list[int]] = None  # set dynamically for each forward pass
+    # set dynamically for each forward pass
+    dp_metadata: Optional[DPMetadata] = None


 _forward_context: Optional[ForwardContext] = None
```
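To make the new fields concrete: below is a standalone sketch of what `DPMetadata` ends up holding for four hypothetical DP ranks with uneven batch sizes, and how a downstream consumer (say, the DP-aware MoE path this PR targets) could slice a gathered buffer using the cumulative boundaries. The rank counts and the slicing logic are illustrative, not code from the PR:

```python
import torch
from dataclasses import dataclass
from typing import Optional


@dataclass
class DPMetadata:  # mirrors the dataclass added by this PR
    num_tokens_across_dp: list[int]
    cu_tokens_across_dp_cpu: torch.Tensor


# Suppose 4 DP ranks ran forward passes with uneven token counts.
num_tokens_across_dp = [8, 3, 5, 2]  # hypothetical per-rank batch sizes
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                 device="cpu",
                                 dtype=torch.int32)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)

meta = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)
print(meta.cu_tokens_across_dp_cpu.tolist())  # [8, 11, 16, 18]

# Rank r's token slice in a buffer gathered across DP ranks is bounded by
# the cumulative counts: [cu[r-1], cu[r]) (with cu[-1] taken as 0).
rank = 2
start = 0 if rank == 0 else int(meta.cu_tokens_across_dp_cpu[rank - 1])
end = int(meta.cu_tokens_across_dp_cpu[rank])
print(start, end)  # 11 16
```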
```diff
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
-    num_tokens_across_dp = None
+    dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
```
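The lines elided between this hunk and the next assemble `num_tokens_tensor` from the local batch size. For the sum all-reduce in the next hunk to yield correct global counts, each rank must contribute a zero-filled vector with only its own count set at its own index. A single-process simulation of that exchange (the per-rank counts are made up and no torch.distributed process group is spun up here):

```python
import torch

dp_size = 4
local_counts = {0: 8, 1: 3, 2: 5, 3: 2}  # hypothetical per-rank batch sizes

# Each rank's pre-all-reduce tensor: zeros except at its own index.
per_rank_tensors = []
for dp_rank in range(dp_size):
    t = torch.zeros(dp_size, device="cpu", dtype=torch.int32)
    t[dp_rank] = local_counts[dp_rank]
    per_rank_tensors.append(t)

# dist.all_reduce(..., group=get_dp_group().cpu_group) sums these tensors
# element-wise across ranks; stacking and summing simulates that result.
num_tokens_tensor = torch.stack(per_rank_tensors).sum(dim=0)
print(num_tokens_tensor.tolist())  # [8, 3, 5, 2] -- identical on every rank
```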
```diff
@@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any,
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        num_tokens_across_dp = num_tokens_tensor.tolist()
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)
```
Review comment on the lines above: "Looks like […]"

Author reply: "You are absolutely right, thank you for catching that."
```diff


     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        attn_layers=vllm_config.compilation_config.static_forward_context,
+        no_compile_layers=vllm_config.compilation_config.
+        static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        num_tokens_across_dp=num_tokens_across_dp)
+        dp_metadata=dp_metadata)
     try:
         yield
     finally:
```
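The `try`/`yield`/`finally` tail shows that `set_forward_context` is a swap-and-restore context manager over the module-level `_forward_context`. A stripped-down sketch of just that shape, with the construction details omitted (simplified signature, not the PR's actual one):

```python
from contextlib import contextmanager
from typing import Any, Optional

_forward_context: Optional[Any] = None  # module-level "current forward pass" state


@contextmanager
def set_forward_context(ctx: Any):
    # Swap the new context in, remember the old one, and always restore it,
    # so nested or failed forward passes cannot leak context state.
    global _forward_context
    prev_context = _forward_context
    _forward_context = ctx
    try:
        yield
    finally:
        _forward_context = prev_context


# Usage: everything inside the `with` block sees ctx as the current context.
with set_forward_context({"dp_metadata": None}):
    assert _forward_context is not None
assert _forward_context is None
```

This pattern is what lets model code anywhere in the forward pass read per-pass state (attention metadata, DP metadata) from a global without threading it through every call signature.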