Commit 6fc8059

Resolved the error stemming from working on an older, not-yet-updated version of main
1 parent 4fab476 commit 6fc8059

File tree: 7 files changed, +63 -31 lines

changes.patch

Whitespace-only changes.

tests/core/test_num_computed_tokens_update.py

Lines changed: 0 additions & 5 deletions
@@ -2,11 +2,6 @@
 
 import pytest
 
-import sys
-import os
-
-SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.dirname(SCRIPT_DIR))
 
 
 from tests.conftest import VllmRunner

vllm/entrypoints/llm.py

Lines changed: 0 additions & 1 deletion
@@ -1191,7 +1191,6 @@ def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
-        print(device)
         return self.llm_engine.reset_prefix_cache(device)
 
     def sleep(self, level: int = 1):

vllm/outputs.py

Lines changed: 14 additions & 2 deletions
@@ -137,7 +137,9 @@ def __init__(
         self.num_cached_tokens = num_cached_tokens
         if hidden_states is not None:
             self.hidden_states = hidden_states
-        #pdb.set_trace()
+
+
+
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""

@@ -180,7 +182,11 @@ def from_seq_group(
                 group.finish_seq(seq_group)
             if assembled_seq_group is None:
                 return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
 
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(

@@ -203,6 +209,7 @@ def from_seq_group(
         top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
+
         # NOTE: We need omit logprobs here explicitly because the sequence
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.

@@ -228,7 +235,12 @@ def from_seq_group(
            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
-                    output_logprobs = output_logprobs[-num_output_tokens:]
+                    # num_output_tokens can be 0 when n > 1 and request finishes
+                    # before the others
+                    if num_output_tokens > 0:
+                        output_logprobs = output_logprobs[-num_output_tokens:]
+                    else:
+                        output_logprobs = None
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
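
The num_output_tokens > 0 guard above is needed because of how Python handles a zero-length negative slice: seq[-0:] is the same as seq[0:] and returns the whole sequence, not an empty one. A minimal, standalone illustration of the pitfall (plain Python, independent of vLLM internals):

    # When a sequence finished earlier than its siblings (n > 1), the delta
    # for this step carries zero new tokens.
    output_logprobs = [{"a": -0.1}, {"b": -0.2}, {"c": -0.3}]

    num_output_tokens = 2
    print(output_logprobs[-num_output_tokens:])  # last two entries, as intended

    num_output_tokens = 0
    print(output_logprobs[-num_output_tokens:])  # the WHOLE list, since [-0:] == [0:]

    # The patched code therefore short-circuits to None instead of slicing:
    delta_logprobs = (output_logprobs[-num_output_tokens:]
                      if num_output_tokens > 0 else None)
    print(delta_logprobs)  # None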

vllm/v1/engine/output_processor.py

Lines changed: 6 additions & 6 deletions
@@ -86,7 +86,7 @@ def __init__(
         detokenizer: IncrementalDetokenizer,
         max_tokens_param: Optional[int],
         arrival_time: float,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ):
         self.request_id = request_id

@@ -113,7 +113,7 @@ def from_new_request(
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest],
         request_index: int,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ) -> "RequestState":
         if not request.sampling_params.detokenize:

@@ -155,7 +155,7 @@ def make_request_output(
 
         # In follow up, we will switch to invariant where EngineCore
         # does not stream partial prefills.
-        if not finished and (self.is_prefilling or final_only):
+        if not finished and final_only:
             # Only the final output is required in FINAL_ONLY mode.
             return None
 

@@ -281,7 +281,7 @@ def add_request(
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest] = None,
         request_index: int = 0,
-        queue: Optional[asyncio.Queue[RequestOutput]] = None,
+        queue: Optional[RequestOutputCollector] = None,
     ) -> None:
         request_id = request.request_id
         if request_id in self.request_states:

@@ -361,7 +361,7 @@ def process_outputs(
             #
             # Follow up will aggregate partial prompt logprobs
             # in the EngineCore.
-            req_state.is_prefilling = not new_token_ids
+            req_state.is_prefilling = False
 
             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(

@@ -379,7 +379,7 @@ def process_outputs(
                     new_token_ids, finish_reason, stop_reason, hidden_states):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
-                    req_state.queue.put_nowait(request_output)
+                    req_state.queue.put(request_output)
                 else:
                     # LLMEngine: return list of RequestOutputs.
                     request_outputs.append(request_output)
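
Swapping the queue type from asyncio.Queue[RequestOutput] to RequestOutputCollector is what allows the synchronous put() call above: the output processor hands each RequestOutput to the collector, and outputs that arrive between consumer reads can be merged via RequestOutput.add() (shown in the vllm/outputs.py hunk earlier) instead of piling up as separate queue items. The sketch below shows the general single-slot collector pattern under those assumptions; it is an illustration of the idea, not vLLM's actual RequestOutputCollector implementation.

    import asyncio
    from typing import Optional

    class OutputCollectorSketch:
        """Single-slot, coalescing hand-off between the output processor
        (producer, synchronous put) and generate() (consumer, async get)."""

        def __init__(self) -> None:
            self._new_output_event = asyncio.Event()
            self._output: Optional[object] = None

        def put(self, output) -> None:
            # Non-blocking: either stash the output or fold it into the one
            # already waiting (RequestOutput.add() merges a newer delta).
            if self._output is None:
                self._output = output
            else:
                self._output.add(output)
            self._new_output_event.set()

        async def get(self):
            # Wait until at least one output has been put, then take it.
            await self._new_output_event.wait()
            output, self._output = self._output, None
            self._new_output_event.clear()
            return output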

vllm/v1/outputs.py

Lines changed: 22 additions & 0 deletions
@@ -40,6 +40,28 @@ def tolists(self):
         )
 
 
+
+
+    @staticmethod
+    def empty_cpu(num_positions: int,
+                  num_tokens_per_position: int) -> "LogprobsTensors":
+        """Create empty LogprobsTensors on CPU."""
+
+        logprob_token_ids = torch.empty(
+            (num_positions, num_tokens_per_position),
+            dtype=torch.int32,
+            device="cpu")
+        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
+        selected_token_ranks = torch.empty(num_positions,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        return LogprobsTensors(
+            logprob_token_ids=logprob_token_ids,
+            logprobs=logprobs,
+            selected_token_ranks=selected_token_ranks,
+        )
+
+
 @dataclass
 class SamplerOutput:
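
Usage of the new helper is direct: it pre-allocates the three CPU tensors with matching shapes and dtypes so callers can reserve space for, say, prompt logprobs before filling them in. A small sketch, assuming the class is importable from vllm.v1.outputs as the file path suggests and keeps the field names shown above:

    from vllm.v1.outputs import LogprobsTensors

    # Reserve space for 8 positions with 5 candidate logprobs per position.
    empty = LogprobsTensors.empty_cpu(num_positions=8, num_tokens_per_position=5)

    assert empty.logprob_token_ids.shape == (8, 5)   # int32 candidate token ids
    assert empty.logprobs.shape == (8, 5)            # float32 logprob values
    assert empty.selected_token_ranks.shape == (8,)  # int32 rank of the sampled token
    assert empty.logprob_token_ids.device.type == "cpu"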

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 17 deletions
@@ -45,7 +45,7 @@
 if TYPE_CHECKING:
     import xgrammar as xgr
 
-    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 

@@ -127,6 +127,7 @@ def __init__(
 
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY

@@ -150,16 +151,18 @@ def __init__(
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
+
             # TODO: find a better way to check if we are using ngram.
-            assert self.speculative_config.ngram_prompt_lookup_min, \
+            assert self.speculative_config.method == "ngram", \
                 "Currently, only ngram spec decode is supported in V1."
             if get_pp_group().is_last_rank:
                 self.drafter = NgramProposer()
                 # Trigger Numba JIT compilation for N-gram proposer.
                 # This usually takes less than 1 second.
                 self.drafter.propose(
                     np.zeros(1024, dtype=np.int32),
-                    self.speculative_config.ngram_prompt_lookup_min,
+                    self.speculative_config.prompt_lookup_min,
+                    self.speculative_config.prompt_lookup_max,
                     self.speculative_config.num_speculative_tokens,
                 )
             self.rejection_sampler = RejectionSampler()

@@ -566,10 +569,12 @@ def _prepare_inputs(
             non_blocking=True)
 
         # Prepare for cascade attention if needed.
-        common_prefix_len = self._compute_cascade_attn_prefix_len(
-            num_scheduled_tokens,
-            scheduler_output.num_common_prefix_blocks,
-        )
+        common_prefix_len = 0
+        if self.cascade_attn_enabled:
+            common_prefix_len = self._compute_cascade_attn_prefix_len(
+                num_scheduled_tokens,
+                scheduler_output.num_common_prefix_blocks,
+            )
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,

@@ -1127,16 +1132,15 @@ def execute_model(
                 logprobs=logprobs_lists,
                 prompt_logprobs_dict=prompt_logprobs_dict,
                 hidden_states=hidden_states)
-        else:
-            return ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=valid_sampled_token_ids,
-                spec_token_ids=spec_token_ids,
-                logprobs=logprobs_lists,
-                prompt_logprobs_dict=prompt_logprobs_dict,
-            )
-
+
+        return ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=valid_sampled_token_ids,
+            spec_token_ids=spec_token_ids,
+            logprobs=logprobs_lists,
+            prompt_logprobs_dict=prompt_logprobs_dict,
+        )
     def generate_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
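
For context on the speculative-decoding changes above: the config fields are now prompt_lookup_min / prompt_lookup_max, and both bounds are passed to the drafter. N-gram (prompt-lookup) drafting matches the most recent n-gram of the sequence, with n between those two bounds, against earlier tokens and proposes whatever followed the match. A toy, self-contained version of that idea is sketched below; it is not vLLM's NgramProposer, whose exact signature and matching strategy may differ.

    import numpy as np

    def ngram_propose(tokens: np.ndarray, min_n: int, max_n: int,
                      num_speculative_tokens: int):
        """Toy prompt-lookup drafter: match the trailing n-gram (n from max_n
        down to min_n) against earlier context and propose what followed it."""
        for n in range(max_n, min_n - 1, -1):
            if len(tokens) <= n:
                continue
            suffix = tokens[-n:]
            # Scan earlier positions for the same n-gram (most recent match wins).
            for start in range(len(tokens) - n - 1, -1, -1):
                if np.array_equal(tokens[start:start + n], suffix):
                    draft = tokens[start + n:start + n + num_speculative_tokens]
                    if len(draft) > 0:
                        return draft
        return None

    context = np.array([3, 7, 9, 4, 5, 6, 1, 2, 9, 4], dtype=np.int32)
    print(ngram_propose(context, min_n=1, max_n=3, num_speculative_tokens=2))
    # -> [5 6]: the suffix [9, 4] occurred earlier and was followed by 5, 6.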
