
Added the option of returning hidden states #15434


Closed

Changes from all commits (24 commits)
0b50102 - removed v1 (Settheworldonfireiii, Mar 31, 2025)
54be999 - Update vllm/engine/llm_engine.py (Settheworldonfireiii, Mar 31, 2025)
f1ef852 - Update vllm/outputs.py (Settheworldonfireiii, Mar 31, 2025)
d8fad3d - Update vllm/outputs.py (Settheworldonfireiii, Mar 31, 2025)
bc20a00 - Update vllm/outputs.py (Settheworldonfireiii, Mar 31, 2025)
808bf9e - Update vllm/engine/llm_engine.py (Settheworldonfireiii, Mar 31, 2025)
1edd330 - added test and customized/restored llm_engine.py (Settheworldonfireiii, Mar 31, 2025)
92c64ca - fixed none device issue (Settheworldonfireiii, Mar 31, 2025)
0db7e4e - fixed hidden_states to if hidden_states is not None (Settheworldonfireiii, Mar 31, 2025)
1b9926e - change 3D tensor to temporary store hidden states in llm engine to a … (Settheworldonfireiii, Apr 1, 2025)
69e3593 - locally passed all stuff (Settheworldonfireiii, Apr 1, 2025)
ec04c07 - locally passed all stuff (Settheworldonfireiii, Apr 1, 2025)
9ba998f - locally passed all stuff (Settheworldonfireiii, Apr 1, 2025)
944b438 - locally passed all stuff (Settheworldonfireiii, Apr 1, 2025)
79eba9a - correct yapf and ruff (Settheworldonfireiii, Apr 1, 2025)
46c29df - fixed ruff (Settheworldonfireiii, Apr 1, 2025)
0fd0d2a - fixed line 1175 in llm_engine.py (Settheworldonfireiii, Apr 2, 2025)
9e36cf1 - fixed llm engine hidden_states associated with value (Settheworldonfireiii, Apr 2, 2025)
131ceb2 - fixed llm engine hidden_states associated with value (Settheworldonfireiii, Apr 2, 2025)
dc4f311 - fixed llm engine hidden_states associated with value (Settheworldonfireiii, Apr 2, 2025)
3944594 - fixed llm engine hidden_states associated with value (Settheworldonfireiii, Apr 2, 2025)
1cea671 - fixed styling : brace (Settheworldonfireiii, Apr 2, 2025)
320694c - removed redundant files (Settheworldonfireiii, Apr 2, 2025)
eda3470 - fixed yapf (Settheworldonfireiii, Apr 2, 2025)
21 changes: 21 additions & 0 deletions tests/entrypoints/llm/test_return_hidden_states.py
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from vllm import LLM, SamplingParams


@pytest.mark.skip_global_cleanup
def test_return_hidden_states():
model = LLM("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
sampling_params = SamplingParams(skip_special_tokens=False,
return_hidden_states=True)
prompt = "Now, tell me about Aspect's experiment \
related to EPR and quantum physics"

o = model.generate(
prompt,
sampling_params=sampling_params,
)

assert isinstance(o[0].hidden_states, torch.Tensor)
20 changes: 19 additions & 1 deletion vllm/engine/llm_engine.py
@@ -1102,10 +1102,26 @@ def _process_model_outputs(self,
continue

output: List[SequenceGroupOutput]

if has_multiple_outputs:
output = outputs_by_sequence_group[i]
if self.model_config.task == "generate" and \
output[0].hidden_states is not None:
hidden_states = []
for k in range(len(output)):
hidden_states.append(
outputs_by_sequence_group[i][k].hidden_states)
else:
hidden_states = None
else:
output = [outputs_by_sequence_group[0][i]]
if self.model_config.task == "generate" and \
hasattr(outputs_by_sequence_group[0], "hidden_states") \
and outputs_by_sequence_group[0].hidden_states \
is not None:
hidden_states = outputs_by_sequence_group[0].hidden_states
else:
hidden_states = None

if not is_async:
if self.scheduler_config.is_multi_step:
@@ -1152,10 +1168,12 @@
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)

request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
use_cache=self.use_cached_outputs,
hidden_states=hidden_states)
if request_output:
ctx.request_outputs.append(request_output)

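For readability, here is a condensed sketch of the hidden-state gathering added to _process_model_outputs above. It mirrors the diff's control flow with the surrounding engine code stripped out; names follow the diff, but this is an illustration, not the exact vLLM implementation.

# Condensed sketch of the hidden-state gathering logic above (illustrative only).
from typing import Any, List, Optional, Union

def gather_hidden_states(task: str,
                         has_multiple_outputs: bool,
                         outputs_by_sequence_group: List[Any],
                         i: int) -> Optional[Union[list, Any]]:
    """Return the hidden states for sequence group i, or None if absent."""
    if task != "generate":
        return None
    if has_multiple_outputs:
        # Multi-step scheduling: one SequenceGroupOutput per scheduled step,
        # so collect one hidden-states tensor per step.
        step_outputs = outputs_by_sequence_group[i]
        if step_outputs and step_outputs[0].hidden_states is not None:
            return [step.hidden_states for step in step_outputs]
        return None
    # Single-step: the batch-level output may carry one hidden-states tensor.
    batch_output = outputs_by_sequence_group[0]
    return getattr(batch_output, "hidden_states", None)
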
19 changes: 13 additions & 6 deletions vllm/outputs.py
@@ -118,6 +118,7 @@ def __init__(
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[list[int]] = None,
num_cached_tokens: Optional[int] = None,
hidden_states: Optional[torch.Tensor] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
) -> None:
@@ -133,6 +134,7 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.hidden_states = hidden_states

def add(self, next_output: "RequestOutput") -> None:
"""Merge subsequent RequestOutput into this one"""
@@ -160,8 +162,11 @@ def add(self, next_output: "RequestOutput") -> None:

@classmethod
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: dict[str, SequenceGroupBase]
cls,
seq_group: SequenceGroup,
use_cache: bool,
seq_id_to_seq_group: dict[str, SequenceGroupBase],
hidden_states: Optional[torch.Tensor] = None,
) -> Optional["RequestOutput"]:
finished = seq_group.is_finished()

@@ -291,7 +296,6 @@ def from_seq_group(
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)

init_kwargs = {
"request_id": seq_group.request_id,
"prompt": prompt,
@@ -304,7 +308,8 @@ def from_seq_group(
"encoder_prompt": encoder_prompt,
"encoder_prompt_token_ids": encoder_prompt_token_ids,
"num_cached_tokens": num_cached_tokens,
"multi_modal_placeholders": seq_group.multi_modal_placeholders
"multi_modal_placeholders": seq_group.multi_modal_placeholders,
"hidden_states": hidden_states,
}

if use_cache:
@@ -385,12 +390,14 @@ class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: dict[str, SequenceGroupBase],
use_cache: bool = False):
use_cache: bool = False,
hidden_states: Optional[torch.Tensor] = None):
if seq_group.pooled_data is not None:
return PoolingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
seq_id_to_seq_group,
hidden_states)


@dataclass
5 changes: 5 additions & 0 deletions vllm/sampling_params.py
@@ -186,6 +186,8 @@ class SamplingParams(
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
return_hidden_states: If provided, hidden states of the last attention
block are returned in the output
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations. Not used by any in-tree sampling
implementations.
@@ -233,6 +235,9 @@ class SamplingParams(
allowed_token_ids: Optional[list[int]] = None
extra_args: Optional[dict[str, Any]] = None

# Output hidden states or not
return_hidden_states: Optional[bool] = None

# Fields used for bad words
bad_words: Optional[list[str]] = None
_bad_words_token_ids: Optional[list[list[int]]] = None
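As a usage illustration of the new flag, here is a minimal sketch of requesting hidden states per request. The attribute names follow this PR's diff and test; the model name, max_tokens value, and tensor layout are illustrative assumptions not specified by the diff.

# Minimal usage sketch of return_hidden_states (assumes this PR's API).
from vllm import LLM, SamplingParams

llm = LLM("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
params = SamplingParams(max_tokens=32, return_hidden_states=True)

outputs = llm.generate(["Explain the EPR paradox in one sentence."],
                       sampling_params=params)

for out in outputs:
    print(out.outputs[0].text)
    # Per the test above, `hidden_states` is a torch.Tensor when the flag is set.
    if out.hidden_states is not None:
        print("hidden_states shape:", tuple(out.hidden_states.shape))
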
1 change: 1 addition & 0 deletions vllm/sequence.py
@@ -1093,6 +1093,7 @@ class CompletionSequenceGroupOutput(
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
step_index: Optional[int] = 0
hidden_states: Optional[torch.Tensor] = None

def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
13 changes: 13 additions & 0 deletions vllm/worker/model_runner.py
@@ -1714,6 +1714,19 @@ def execute_model(
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
previous_hidden_states = kwargs.get("previous_hidden_states")

# overrides self.return_hidden_states that was
# assigned during initialization
# the rationale is giving users the option
# to receive hidden states or not
# from the same model w/o re-init it
if (model_input.sampling_metadata is not None
and hasattr(model_input.sampling_metadata, 'seq_groups')
and model_input.sampling_metadata.seq_groups is not None):
self.return_hidden_states = (
model_input.sampling_metadata.seq_groups[0].sampling_params.
return_hidden_states)

Review comment on lines +1718 to +1729 (from a project member):

This will not work. You are making return_hidden_states a request-level parameter, which means a batch can be a mixture: some requests require token output, while others require hidden-state output. This cannot be handled well.

I don't think we will accept this feature. If you really want it, you can change the code in vllm/worker/model_runner.py to write the tensor output to a file, and then read the file directly.

A similar ask is to get attention masks from vllm, which we will not accept either.

We might accept these as a tutorial, saying that this feature will not be supported in vllm, but if you want to have it, here is how you can modify vllm's code to achieve it, just for your own usage.

if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
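For completeness, a rough sketch of the workaround the reviewer suggests: patch vllm/worker/model_runner.py to dump the tensor to disk instead of extending the public API. The hook point and the local variable name inside execute_model are assumptions that depend on the vLLM version; adapt as needed.

# Rough sketch of the reviewer's suggested workaround: write hidden states to a
# file from model_runner.py and read them back offline. The local variable name
# `hidden_or_intermediate_states` is an assumption about execute_model internals.
import os
import torch

def dump_hidden_states(hidden_states: torch.Tensor,
                       out_dir: str = "/tmp/vllm_hidden_states",
                       step: int = 0) -> None:
    """Save one step's hidden states to disk for later analysis."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hidden_states.detach().cpu(),
               os.path.join(out_dir, f"step_{step:06d}.pt"))

# Inside ModelRunner.execute_model, after the forward pass, one could call:
#     dump_hidden_states(hidden_or_intermediate_states, step=step_counter)
# and later recover the tensors with torch.load(<path>).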