@@ -32,8 +32,9 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
-                        check_use_alibi, is_pin_memory_available)
+                        GiB_bytes, LayerBlockType, LazyLoader,
+                        async_tensor_h2d, cdiv, check_use_alibi,
+                        is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
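The hunk above is mechanical: the `vllm.utils` import list is rewrapped to pull in `LayerBlockType` alongside the existing helpers, presumably for use elsewhere in this file; nothing shown here changes behavior. The two hunks below carry the substantive change: the `total_num_scheduled_tokens` metadata field is dropped, and its one visible use is replaced with a value already in scope.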
@@ -895,7 +896,6 @@ def _calc_spec_decode_metadata(
             target_logits_indices=target_logits_indices,
             bonus_logits_indices=bonus_logits_indices,
             logits_indices=logits_indices,
-            total_num_scheduled_tokens=cu_num_scheduled_tokens[-1],
         )
         return metadata
 
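The removed field duplicated information the caller already holds: `cu_num_scheduled_tokens[-1]`, the last entry of the cumulative sum over per-request scheduled-token counts, is by construction the same number as the plain total. A minimal sketch of that invariant, with illustrative values rather than vLLM code:

```python
import numpy as np

# Hypothetical per-request scheduled-token counts for one batch.
num_scheduled_tokens_per_req = np.array([4, 1, 3, 2])

# Cumulative sum, as _calc_spec_decode_metadata builds for indexing.
cu_num_scheduled_tokens = np.cumsum(num_scheduled_tokens_per_req)

# The last cumulative entry is always the plain total (10 here),
# so the dropped metadata field carried no information of its own.
assert cu_num_scheduled_tokens[-1] == num_scheduled_tokens_per_req.sum()
```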
@@ -1388,8 +1388,7 @@ def execute_model(
                 dtype=torch.int32,
                 target_device=self.device,
                 pin_memory=True)
-            num_tokens = spec_decode_metadata.total_num_scheduled_tokens - \
-                sum(num_rejected_tokens)
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 eagle_attn_metadata.query_start_loc,
                 num_rejected_tokens_tensor,
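With the field gone, `execute_model` computes the drafter's token count from the local `num_scheduled_tokens` instead of reading it back off the metadata object. A hedged sketch of the arithmetic (illustrative names and values, not the vLLM implementation): each request schedules its proposed draft tokens plus one verified token, and the rejected drafts are subtracted batch-wide before the drafter prepares its inputs.

```python
# Illustrative values, not vLLM code.
num_draft_tokens = [3, 2, 0, 4]      # drafts proposed per request
num_accepted_tokens = [2, 2, 0, 1]   # drafts the target model accepted

# Rejected drafts per request: proposed minus accepted.
num_rejected_tokens = [d - a
                       for d, a in zip(num_draft_tokens, num_accepted_tokens)]

# Each request schedules its drafts plus one verified token: 9 + 4 = 13.
num_scheduled_tokens = sum(num_draft_tokens) + len(num_draft_tokens)

# Tokens surviving verification, which the drafter prepares inputs for:
# 13 scheduled - 4 rejected = 9.
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
assert num_tokens == 9
```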