Commit 0b50102

removed v1
Signed-off-by: Settheworldonfireiii <[email protected]>
1 parent 9b459ec commit 0b50102

File tree

5 files changed: 90 additions & 24 deletions

vllm/engine/llm_engine.py

Lines changed: 23 additions & 4 deletions
@@ -1102,10 +1102,22 @@ def _process_model_outputs(self,
                 continue
 
             output: List[SequenceGroupOutput]
+            return_hidden_states = False
             if has_multiple_outputs:
                 output = outputs_by_sequence_group[i]
+                if self.model_config.task == "generate" and hasattr(
+                        outputs_by_sequence_group[0][0], 'hidden_states'):
+                    return_hidden_states = True
+                    for k in range(len(outputs_by_sequence_group[i])):
+                        output[k].hidden_states = outputs_by_sequence_group[i][
+                            k].hidden_states
             else:
                 output = [outputs_by_sequence_group[0][i]]
+                if self.model_config.task == "generate" and hasattr(
+                        outputs_by_sequence_group[0], 'hidden_states'):
+                    return_hidden_states = True
+                    output[0].hidden_states = outputs_by_sequence_group[
+                        0].hidden_states
 
             if not is_async:
                 if self.scheduler_config.is_multi_step:
@@ -1152,10 +1164,17 @@ def _process_model_outputs(self,
             seq_group.maybe_set_first_token_time(now)
             if not seq_group.is_prefill():
                 seq_group.set_last_token_time(now)
-            request_output = RequestOutputFactory.create(
-                seq_group,
-                self.seq_id_to_seq_group,
-                use_cache=self.use_cached_outputs)
+            if return_hidden_states:
+                request_output = RequestOutputFactory.create(
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs,
+                    hidden_states=output[0].hidden_states)
+            else:
+                request_output = RequestOutputFactory.create(
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 

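The two hunks above thread an optional hidden-states tensor from the per-step sequence-group outputs into RequestOutputFactory.create. A minimal, self-contained sketch of that pattern follows; FakeRequestOutput and create_output are hypothetical stand-ins for illustration, not vLLM code.

# Sketch of the optional hidden-states plumbing introduced above.
# FakeRequestOutput / create_output are hypothetical stand-ins, not vLLM APIs.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeRequestOutput:
    request_id: str
    # The field is always declared here; in the commit, RequestOutput only
    # sets self.hidden_states when a tensor is actually passed in.
    hidden_states: Optional[Any] = None


def create_output(request_id: str,
                  hidden_states: Optional[Any] = None) -> FakeRequestOutput:
    # Mirrors RequestOutputFactory.create: forward hidden_states only when the
    # caller requested it, otherwise build the output exactly as before.
    if hidden_states is not None:
        return FakeRequestOutput(request_id, hidden_states=hidden_states)
    return FakeRequestOutput(request_id)


print(create_output("req-0").hidden_states)              # None
print(create_output("req-1", [0.1, 0.2]).hidden_states)  # [0.1, 0.2]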
vllm/outputs.py

Lines changed: 48 additions & 20 deletions
@@ -118,6 +118,7 @@ def __init__(
         encoder_prompt: Optional[str] = None,
         encoder_prompt_token_ids: Optional[list[int]] = None,
         num_cached_tokens: Optional[int] = None,
+        hidden_states: Optional[torch.Tensor] = None,
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
     ) -> None:
@@ -133,6 +134,8 @@ def __init__(
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        if hidden_states is not None:
+            self.hidden_states = hidden_states
 
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""
@@ -160,8 +163,11 @@ def add(self, next_output: "RequestOutput") -> None:
 
     @classmethod
     def from_seq_group(
-        cls, seq_group: SequenceGroup, use_cache: bool,
-        seq_id_to_seq_group: dict[str, SequenceGroupBase]
+        cls,
+        seq_group: SequenceGroup,
+        use_cache: bool,
+        seq_id_to_seq_group: dict[str, SequenceGroupBase],
+        hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional["RequestOutput"]:
         finished = seq_group.is_finished()
 
@@ -291,21 +297,37 @@ def from_seq_group(
             prompt_logprobs = None
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
-
-        init_kwargs = {
-            "request_id": seq_group.request_id,
-            "prompt": prompt,
-            "prompt_token_ids": prompt_token_ids,
-            "prompt_logprobs": prompt_logprobs,
-            "outputs": outputs,
-            "finished": finished,
-            "metrics": seq_group.metrics,
-            "lora_request": seq_group.lora_request,
-            "encoder_prompt": encoder_prompt,
-            "encoder_prompt_token_ids": encoder_prompt_token_ids,
-            "num_cached_tokens": num_cached_tokens,
-            "multi_modal_placeholders": seq_group.multi_modal_placeholders
-        }
+        if hidden_states is not None:
+            init_kwargs = {
+                "request_id": seq_group.request_id,
+                "prompt": prompt,
+                "prompt_token_ids": prompt_token_ids,
+                "prompt_logprobs": prompt_logprobs,
+                "outputs": outputs,
+                "finished": finished,
+                "metrics": seq_group.metrics,
+                "lora_request": seq_group.lora_request,
+                "encoder_prompt": encoder_prompt,
+                "encoder_prompt_token_ids": encoder_prompt_token_ids,
+                "num_cached_tokens": num_cached_tokens,
+                "multi_modal_placeholders": seq_group.multi_modal_placeholders,
+                "hidden_states": hidden_states,
+            }
+        else:
+            init_kwargs = {
+                "request_id": seq_group.request_id,
+                "prompt": prompt,
+                "prompt_token_ids": prompt_token_ids,
+                "prompt_logprobs": prompt_logprobs,
+                "outputs": outputs,
+                "finished": finished,
+                "metrics": seq_group.metrics,
+                "lora_request": seq_group.lora_request,
+                "encoder_prompt": encoder_prompt,
+                "encoder_prompt_token_ids": encoder_prompt_token_ids,
+                "num_cached_tokens": num_cached_tokens,
+                "multi_modal_placeholders": seq_group.multi_modal_placeholders,
+            }
 
         if use_cache:
             request_output = seq_group.cached_request_output
@@ -385,12 +407,18 @@ class RequestOutputFactory:
     @staticmethod
     def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: dict[str, SequenceGroupBase],
-              use_cache: bool = False):
+              use_cache: bool = False,
+              hidden_states: Optional[torch.Tensor] = None):
         if seq_group.pooled_data is not None:
             return PoolingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache,
-                                                seq_id_to_seq_group)
+            if hidden_states is not None:
+                return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                    seq_id_to_seq_group,
+                                                    hidden_states)
+            else:
+                return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                    seq_id_to_seq_group)
 
 
 @dataclass

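In the from_seq_group hunk, the two init_kwargs dictionaries differ only in the optional "hidden_states" entry. A hedged sketch of an equivalent, shorter construction (not part of this commit) builds the dict once and adds the key conditionally:

# Hedged sketch, not part of the commit: build init_kwargs once and add the
# optional key only when a tensor was supplied. The field list is abbreviated.
from typing import Any, Dict, Optional


def build_init_kwargs(request_id: str,
                      hidden_states: Optional[Any] = None) -> Dict[str, Any]:
    init_kwargs: Dict[str, Any] = {
        "request_id": request_id,
        # ... the remaining fields from the hunk above go here unchanged ...
    }
    if hidden_states is not None:
        init_kwargs["hidden_states"] = hidden_states
    return init_kwargs


print("hidden_states" in build_init_kwargs("req-0"))         # False
print("hidden_states" in build_init_kwargs("req-1", [0.5]))  # True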
vllm/sampling_params.py

Lines changed: 5 additions & 0 deletions
@@ -186,6 +186,8 @@ class SamplingParams(
         allowed_token_ids: If provided, the engine will construct a logits
             processor which only retains scores for the given token ids.
             Defaults to None.
+        return_hidden_states: If provided, hidden states of the last attention
+            block are returned in the output
         extra_args: Arbitrary additional args, that can be used by custom
             sampling implementations. Not used by any in-tree sampling
             implementations.
@@ -233,6 +235,9 @@ class SamplingParams(
     allowed_token_ids: Optional[list[int]] = None
     extra_args: Optional[dict[str, Any]] = None
 
+    # Output hidden states or not
+    return_hidden_states: Optional[bool] = None
+
     # Fields used for bad words
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None

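With the new flag in place, a request for hidden states could look like the sketch below. It assumes a vLLM build that includes this commit; the model name is illustrative, and hidden_states is only set on outputs for which the engine attached a tensor, so getattr is used defensively.

# Usage sketch, assuming a vLLM build that contains this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = SamplingParams(max_tokens=8, return_hidden_states=True)

for out in llm.generate(["Hello, my name is"], params):
    # hidden_states is only present when the engine attached a tensor.
    hs = getattr(out, "hidden_states", None)
    print(out.outputs[0].text, None if hs is None else tuple(hs.shape))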
vllm/sequence.py

Lines changed: 1 addition & 0 deletions
@@ -1093,6 +1093,7 @@ class CompletionSequenceGroupOutput(
     # Prompt logprob for each prompt query token.
     prompt_logprobs: Optional[PromptLogprobs]
     step_index: Optional[int] = 0
+    hidden_states: Optional[torch.Tensor] = None
 
     def __repr__(self) -> str:
         return (f"CompletionSequenceGroupOutput(samples={self.samples}, "

vllm/worker/model_runner.py

Lines changed: 13 additions & 0 deletions
@@ -1714,6 +1714,19 @@ def execute_model(
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
         previous_hidden_states = kwargs.get("previous_hidden_states")
+
+        # overrides self.return_hidden_states that was
+        # assigned during initialization
+        # the rationale is giving users the option
+        # to receive hidden states or not
+        # from the same model w/o re-init it
+        if (model_input.sampling_metadata is not None
+                and hasattr(model_input.sampling_metadata, 'seq_groups')
+                and model_input.sampling_metadata.seq_groups is not None):
+            self.return_hidden_states = (
+                model_input.sampling_metadata.seq_groups[0].sampling_params.
+                return_hidden_states)
+
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]

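The override above reads return_hidden_states from the sampling params of the first sequence group in the batch, so every request scheduled in the same step follows that first request. A self-contained sketch of that resolution logic, using hypothetical stand-in classes rather than vLLM types:

# Sketch of the per-step override above; the Fake* classes are hypothetical.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeSamplingParams:
    return_hidden_states: Optional[bool] = None


@dataclass
class FakeSeqGroup:
    sampling_params: FakeSamplingParams


def resolve_return_hidden_states(
        seq_groups: Optional[List[FakeSeqGroup]],
        current: bool = False) -> Optional[bool]:
    # Mirrors the guards in the hunk: only override when seq_groups exists.
    if seq_groups:
        return seq_groups[0].sampling_params.return_hidden_states
    return current


batch = [FakeSeqGroup(FakeSamplingParams(True)),
         FakeSeqGroup(FakeSamplingParams(False))]
print(resolve_return_hidden_states(batch))  # True: the first request decides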