Commit d55e446

[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)
Signed-off-by: qizixi <[email protected]>
1 parent: ec82c3e

3 files changed: +21 -19 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 5 additions & 1 deletion
@@ -100,8 +100,12 @@ def test_prepare_inputs():
                                        dtype=torch.int32,
                                        device=device)
 
+    # n1 + n2 + n3 - a - b -c
+    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
+    ).item()
+
     cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens)
+        cu_target_query_lens, num_rejected_tokens, num_tokens)
 
     assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
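In plain terms, the test now precomputes the total number of retained tokens on the host, i.e. the sum of the target query lengths minus the total rejected tokens, (a + b + c) - (n1 + n2 + n3), and passes it to EagleProposer.prepare_inputs as a third argument. A toy calculation with made-up lengths (illustrative only, not values from the test):

import torch

# Hypothetical values: per-request target lengths a=4, b=3, c=5 and
# rejected counts n1=1, n2=0, n3=2 (not taken from the commit itself).
cu_target_query_lens = torch.tensor([0, 4, 7, 12], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

# (a + b + c) - (n1 + n2 + n3) = 12 - 3 = 9 retained tokens
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum().item()
assert num_tokens == 9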

vllm/v1/spec_decode/eagle.py

Lines changed: 3 additions & 7 deletions
@@ -271,6 +271,7 @@ def prepare_inputs(
         cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # cu_target_query_lens: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
@@ -288,18 +289,13 @@ def prepare_inputs(
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
+        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
         token_indices = torch.empty(
             num_tokens,
             dtype=torch.int32,
-            device=cu_num_tokens.device,
+            device=cu_target_query_lens.device,
         )
-
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
         prepare_eagle_input_kernel[(batch_size, )](
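Two things change here: cu_num_tokens is allocated with torch.zeros_like, so the explicit cu_num_tokens[0] = 0 write is no longer needed, and num_tokens is now accepted as an argument, which removes the cu_num_tokens[-1].item() device-to-host synchronization that the deleted FIXME pointed at. A minimal sketch of the cumulative bookkeeping on the toy values above (a standalone illustration, not the vLLM code path):

import torch

cu_target_query_lens = torch.tensor([0, 4, 7, 12], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

# Per-request retained lengths: [a - n1, b - n2, c - n3]
query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
num_tokens_per_req = query_len_per_req - num_rejected_tokens

# zeros_like already supplies the leading 0, so only [1:] receives the cumsum.
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
print(cu_num_tokens)  # tensor([0, 3, 6, 9], dtype=torch.int32)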

vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 11 deletions
@@ -34,8 +34,8 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ def __init__(
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
 
@@ -1360,9 +1360,10 @@ def execute_model(
                     scheduler_output.num_scheduled_tokens[req_id])
                 next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@@ -1390,14 +1391,16 @@ def execute_model(
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens = torch.tensor(
+                num_rejected_tokens_tensor = async_tensor_h2d(
                     num_rejected_tokens,
                     dtype=torch.int32,
-                    device=self.device,
-                )
+                    target_device=self.device,
+                    pin_memory=True)
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                     eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens,
+                    num_rejected_tokens_tensor,
+                    num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
                 target_positions = positions[token_indices]
@@ -1408,7 +1411,6 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
-
                 draft_token_ids = self.drafter.propose(
                     target_token_ids=target_token_ids,
                     target_positions=target_positions,
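The model-runner side mirrors the new signature: next_token_ids and the rejected-token counts are moved to the GPU with async_tensor_h2d (newly imported from vllm.utils) using pinned host memory instead of a plain torch.tensor(..., device=self.device) construction, and num_tokens is computed from the Python list on the CPU, so prepare_inputs never has to read it back from the device. (The _may_reorder_batch docstring hunk appears to differ only in trailing whitespace.) A rough stand-in showing the pinned, non-blocking host-to-device pattern such a helper typically uses; the name _pinned_async_h2d and the values below are illustrative, not vLLM's implementation:

import torch

def _pinned_async_h2d(values: list[int], dtype: torch.dtype, device: str) -> torch.Tensor:
    # Stage the Python list in pinned (page-locked) host memory, then launch a
    # non-blocking copy so the CPU can keep queuing work while the transfer runs.
    pin = torch.cuda.is_available()
    cpu_tensor = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=pin)
    return cpu_tensor.to(device, non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
next_token_ids = _pinned_async_h2d([7, 42, 3], torch.int32, device)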
