diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index e000d955cfc..7be1c5b8993 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -100,8 +100,12 @@ def test_prepare_inputs():
                                        dtype=torch.int32,
                                        device=device)
 
+    # a + b + c - n1 - n2 - n3
+    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
+    ).item()
+
     cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens)
+        cu_target_query_lens, num_rejected_tokens, num_tokens)
 
     assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 3926a86ee59..876e1ddd14a 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -271,6 +271,7 @@ def prepare_inputs(
         cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # cu_target_query_lens: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
@@ -288,18 +289,13 @@ def prepare_inputs(
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
+        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
         token_indices = torch.empty(
             num_tokens,
             dtype=torch.int32,
-            device=cu_num_tokens.device,
+            device=cu_target_query_lens.device,
         )
-
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
         prepare_eagle_input_kernel[(batch_size, )](
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 42847e2f8c3..5120495dbb9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -34,8 +34,8 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ def __init__(
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may 
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will
         be compute-bound or memory-bound.
 
@@ -1360,9 +1360,10 @@ def execute_model(
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@@ -1390,14 +1391,16 @@ def execute_model(
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens = torch.tensor(
+                num_rejected_tokens_tensor = async_tensor_h2d(
                     num_rejected_tokens,
                     dtype=torch.int32,
-                    device=self.device,
-                )
+                    target_device=self.device,
+                    pin_memory=True)
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                     eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens,
+                    num_rejected_tokens_tensor,
+                    num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
                 target_positions = positions[token_indices]
@@ -1408,7 +1411,6 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
-
                 draft_token_ids = self.drafter.propose(
                     target_token_ids=target_token_ids,
                     target_positions=target_positions,
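
A minimal usage sketch of the updated prepare_inputs() signature, not part of the
patch: it mirrors test_prepare_inputs() above, the example tensor values are
invented, and it assumes a CUDA device since prepare_eagle_input_kernel is a
Triton kernel. The point is that num_tokens is now supplied by the caller as a
host-side int, so prepare_inputs() no longer calls cu_num_tokens[-1].item() and
synchronizes with the GPU.

    import torch

    from vllm.v1.spec_decode.eagle import EagleProposer

    device = "cuda"

    # cu_target_query_lens: [0, a, a + b, a + b + c] with a=4, b=3, c=5
    cu_target_query_lens = torch.tensor([0, 4, 7, 12],
                                        dtype=torch.int32,
                                        device=device)
    # num_rejected_tokens: [n1, n2, n3]
    num_rejected_tokens = torch.tensor([1, 0, 2],
                                       dtype=torch.int32,
                                       device=device)

    # a + b + c - n1 - n2 - n3; in the model runner this is derived from
    # Python ints the scheduler already has, so no device sync is required.
    num_tokens = 12 - (1 + 0 + 2)

    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
        cu_target_query_lens, num_rejected_tokens, num_tokens)

    # [a - n1, b - n2, c - n3] = [3, 3, 3] -> cumulative [0, 3, 6, 9]
    assert cu_num_tokens.tolist() == [0, 3, 6, 9]
    assert token_indices.shape[0] == num_tokens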