Skip to content

Commit d07a3c5

Browse files
committed
address comments
Signed-off-by: qizixi <[email protected]>
1 parent 4ecf150 commit d07a3c5

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

vllm/v1/spec_decode/metadata.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class SpecDecodeMetadata:
2020
bonus_logits_indices: torch.Tensor
2121
# [num_tokens + batch_size]
2222
logits_indices: torch.Tensor
23-
total_num_scheduled_tokens: int
2423

2524
def __post_init__(self):
2625
self.max_spec_len = max(self.num_draft_tokens)
@@ -59,5 +58,4 @@ def make_dummy(
5958
target_logits_indices=target_logits_indices,
6059
bonus_logits_indices=bonus_logits_indices,
6160
logits_indices=logits_indices,
62-
total_num_scheduled_tokens=num_tokens,
6361
)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
from vllm.sampling_params import SamplingType
3333
from vllm.sequence import IntermediateTensors
3434
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
35-
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
36-
check_use_alibi, is_pin_memory_available)
35+
GiB_bytes, LayerBlockType, LazyLoader,
36+
async_tensor_h2d, cdiv, check_use_alibi,
37+
is_pin_memory_available)
3738
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
3839
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
3940
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -895,7 +896,6 @@ def _calc_spec_decode_metadata(
895896
target_logits_indices=target_logits_indices,
896897
bonus_logits_indices=bonus_logits_indices,
897898
logits_indices=logits_indices,
898-
total_num_scheduled_tokens=cu_num_scheduled_tokens[-1],
899899
)
900900
return metadata
901901

@@ -1388,8 +1388,7 @@ def execute_model(
13881388
dtype=torch.int32,
13891389
target_device=self.device,
13901390
pin_memory=True)
1391-
num_tokens = spec_decode_metadata.total_num_scheduled_tokens - \
1392-
sum(num_rejected_tokens)
1391+
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
13931392
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
13941393
eagle_attn_metadata.query_start_loc,
13951394
num_rejected_tokens_tensor,

0 commit comments

Comments (0)