
[WIP] Echo prompt tokens #833

Closed · wants to merge 2 commits
3 changes: 2 additions & 1 deletion vllm/core/scheduler.py
@@ -299,7 +299,8 @@ def update(
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
+                seq.append_token_id(output.output_token, output.logprobs,
+                                    output.echo)
         return scheduled
 
     def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
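Note: the Sequence side of this call is not part of the diff. Below is a hypothetical sketch of how append_token_id might accept and store the echoed prompt-token outputs; only the echo_logprobs attribute name is grounded in the llm_engine.py hunk further down, everything else is an assumption.

```python
# Hypothetical sketch -- the actual vllm/sequence.py change is not shown in this PR.
# Only the `echo_logprobs` name is grounded in the llm_engine.py hunk below.
from typing import Dict, List


class Sequence:  # heavily trimmed stand-in for vllm.sequence.Sequence
    def __init__(self, seq_id: int):
        self.seq_id = seq_id
        self.output_logprobs: List[Dict[int, float]] = []
        # One entry per generation step; only the first step carries the
        # echoed prompt-token outputs, later steps append an empty list.
        self.echo_logprobs: List[list] = []

    def append_token_id(self, token_id: int, logprobs: Dict[int, float],
                        echo: list) -> None:
        # Keep the prompt-token logprobs produced alongside the first sampled
        # token so the engine can detokenize and echo them later.
        self.echo_logprobs.append(echo)
        self.output_logprobs.append(logprobs)
```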
10 changes: 10 additions & 0 deletions vllm/engine/llm_engine.py
@@ -406,6 +406,16 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Decodes the sequence outputs."""
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                for echo_logprob in seq.echo_logprobs[-1]:
+                    new_token, new_output_text = detokenize_incrementally(
+                        self.tokenizer,
+                        seq.output_tokens,
+                        echo_logprob.output_token,
+                        skip_special_tokens=True,
+                    )
+                    if new_token is not None:
+                        seq.output_tokens.append(new_token)
+                        seq.output_text = new_output_text
                 new_token, new_output_text = detokenize_incrementally(
                     self.tokenizer,
                     seq.output_tokens,
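For readers unfamiliar with the helper: detokenize_incrementally converts one new token id at a time and re-joins the running token list into text. A rough, self-contained illustration of that pattern with a Hugging Face tokenizer (this is not vLLM's actual implementation; the gpt2 tokenizer is just an example):

```python
# Rough illustration of the incremental-detokenization pattern used above.
# NOT vllm's detokenize_incrementally, just the same id -> token -> text idea.
from typing import List, Optional, Tuple

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer


def detokenize_one(prev_tokens: List[str], new_token_id: int,
                   skip_special_tokens: bool = True) -> Tuple[Optional[str], str]:
    new_token = tokenizer.convert_ids_to_tokens(new_token_id)
    if skip_special_tokens and new_token in tokenizer.all_special_tokens:
        # Skip special tokens but still return the current text.
        return None, tokenizer.convert_tokens_to_string(prev_tokens)
    return new_token, tokenizer.convert_tokens_to_string(prev_tokens + [new_token])


output_tokens: List[str] = []
output_text = ""
for token_id in tokenizer.encode("Echoing the prompt"):
    new_token, output_text = detokenize_one(output_tokens, token_id)
    if new_token is not None:
        output_tokens.append(new_token)
print(output_text)  # -> "Echoing the prompt"
```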
9 changes: 1 addition & 8 deletions vllm/entrypoints/openai/api_server.py
@@ -355,8 +355,6 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.
 
     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
          suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -368,12 +366,6 @@ async def create_completion(raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret
 
-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
-
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
         return create_error_response(HTTPStatus.BAD_REQUEST,
@@ -429,6 +421,7 @@ async def create_completion(raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            echo=request.echo,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
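With the guard removed and echo forwarded into SamplingParams, the flag can be exercised through the OpenAI-compatible endpoint. A sketch, assuming the server was started with `python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m` and is listening on localhost:8000 (model name and port are placeholders):

```python
# Sketch of an echo request against the /v1/completions endpoint.
# Server address and model name are assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logprobs": 1,
        "echo": True,  # return the prompt tokens (and their logprobs) as well
    },
)
choice = resp.json()["choices"][0]
# With echo, the OpenAI API convention is for the text to start with the prompt.
print(choice["text"])
```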
87 changes: 79 additions & 8 deletions vllm/model_executor/layers/sampler.py
@@ -37,18 +37,21 @@ def forward(
         self,
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
         input_metadata: InputMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Dict[int, SequenceOutputs]:
+        # Calculate prompt tokens logprobs for the first token generation
+        echo_logits = _logits(embedding, hidden_states, self.vocab_size,
+                              embedding_bias)
+        echo = _echo(echo_logits, input_ids, input_metadata)
+
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)
 
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab (if any).
+        logits = _logits(embedding, hidden_states, self.vocab_size,
+                         embedding_bias)
         logits = logits[:, :self.vocab_size]
 
         # Apply presence and frequency penalties.
@@ -86,7 +89,22 @@ def forward(
         logprobs = torch.log(probs)
 
         # Sample the next tokens.
-        return _sample(probs, logprobs, input_metadata)
+        return _sample(probs, logprobs, input_metadata, echo)
+
+
+def _logits(
+    embedding: torch.Tensor,
+    hidden_states: torch.Tensor,
+    vocab_size: int,
+    embedding_bias: Optional[torch.Tensor] = None,
+):
+    logits = torch.matmul(hidden_states, embedding.t())
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = gather_from_tensor_model_parallel_region(logits)
+    # Remove paddings in vocab (if any).
+    logits = logits[:, :vocab_size]
+    return logits
 
 
 def _prune_hidden_states(
@@ -371,6 +389,7 @@ def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
+    echo=None,
 ) -> Dict[int, SequenceOutputs]:
     seq_outputs: Dict[int, SequenceOutputs] = {}
 
@@ -395,9 +414,12 @@ def _sample(
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
+                seq_outputs[seq_id] = SequenceOutputs(seq_id,
+                                                      seq_id,
                                                       next_token_id,
-                                                      output_logprobs)
+                                                      output_logprobs,
+                                                      echo=echo.get(
+                                                          seq_id, []))
         else:
             # Generate the next tokens for generation tokens.
             prob = probs[idx:idx + len(seq_ids)]
@@ -433,3 +455,52 @@ def _sample(
                 )
 
     return seq_outputs
+
+
+def _echo(
+    echo_logits: torch.Tensor,
+    input_ids: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> Dict[int, SequenceOutputs]:
+    # skip echo logprobs after the first token was generated
+    if input_metadata.num_generation_tokens:
+        return {}
+
+    echo_probs = torch.softmax(echo_logits, dim=-1, dtype=torch.float)
+    echo_logprobs = torch.log(echo_probs)
+
+    # remove padding
+    echo_ids = input_ids[:input_metadata.num_valid_tokens]
+    # remove first token (logprobs None)
+    target_ids = echo_ids[1:]
+
+    target_ids_tensor = target_ids.unsqueeze(0)
+    echo_logprobs_view = echo_logprobs.view(-1, echo_logprobs.shape[1])
+    target_ids_logprobs = torch.gather(echo_logprobs_view, 1,
+                                       target_ids_tensor)
+
+    seq_outputs: Dict[int, SequenceOutputs] = {}
+    for seq_group in input_metadata.seq_groups:
+        seq_ids, sampling_params = seq_group
+        for seq_idx, seq_id in enumerate(seq_ids):
+            echo_tokens = []
+            if sampling_params.echo:
+                # logprobs 0 for first token
+                echo_tokens.append(
+                    SequenceOutputs(
+                        seq_id,
+                        seq_id,
+                        echo_ids[0].item(),
+                        {echo_ids[0].item(): 0.0},
+                    ))
+                for index, t_logp in enumerate(target_ids_logprobs[seq_idx],
+                                               1):
+                    echo_tokens.append(
+                        SequenceOutputs(
+                            seq_id,
+                            seq_id,
+                            echo_ids[index].item(),
+                            {echo_ids[index].item(): t_logp.item()},
+                        ))
+            seq_outputs[seq_id] = echo_tokens
+    return seq_outputs
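The heart of _echo is scoring each prompt token with the log-probabilities computed at the preceding position; the first prompt token has no preceding context, so the diff reports it with logprob 0.0. A standalone sketch of that shifted gather with plain torch tensors, independent of the vLLM data structures:

```python
# Standalone sketch of per-prompt-token logprobs via a shifted gather.
# Shapes and names are illustrative only; no vLLM structures involved.
import torch

torch.manual_seed(0)
vocab_size, prompt_len = 32, 5
echo_logits = torch.randn(prompt_len, vocab_size)        # logits at every prompt position
input_ids = torch.randint(0, vocab_size, (prompt_len,))  # the prompt token ids

echo_logprobs = torch.log_softmax(echo_logits, dim=-1)

# The distribution at position i scores the token that appears at position i + 1.
target_ids = input_ids[1:].unsqueeze(1)                                # (prompt_len - 1, 1)
target_logprobs = torch.gather(echo_logprobs[:-1], 1, target_ids).squeeze(1)

# First token gets 0.0 here, mirroring the diff above (the OpenAI API reports null for it).
prompt_logprobs = [0.0] + target_logprobs.tolist()
print(list(zip(input_ids.tolist(), prompt_logprobs)))
```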
2 changes: 1 addition & 1 deletion vllm/model_executor/models/aquila.py
@@ -276,7 +276,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/baichuan.py
@@ -292,7 +292,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/bloom.py
@@ -268,7 +268,7 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/falcon.py
@@ -406,7 +406,7 @@ def forward(
             cache_events,
         )
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
 
         return next_tokens
 
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gpt2.py
@@ -221,7 +221,7 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gpt_bigcode.py
@@ -249,7 +249,7 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
3 changes: 2 additions & 1 deletion vllm/model_executor/models/gpt_j.py
@@ -207,7 +207,8 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata, self.lm_head.bias)
+                                   input_ids, input_metadata,
+                                   self.lm_head.bias)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gpt_neox.py
@@ -219,7 +219,7 @@ def forward(
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                       input_metadata, cache_events)
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/internlm.py
@@ -221,7 +221,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llama.py
@@ -253,7 +253,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/mpt.py
@@ -234,7 +234,7 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
2 changes: 1 addition & 1 deletion vllm/model_executor/models/opt.py
@@ -286,7 +286,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = [
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen.py
@@ -238,7 +238,7 @@ def forward(
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   input_ids, input_metadata)
         return next_tokens
 
     _column_parallel_weights = ["wte.weight", "lm_head.weight"]
6 changes: 5 additions & 1 deletion vllm/sampling_params.py
@@ -40,6 +40,7 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        echo: Whether to echo the prompt in the output.
     """
 
     def __init__(
@@ -56,6 +57,7 @@ def __init__(
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        echo: Optional[bool] = False,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -74,6 +76,7 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.echo = echo
 
         self._verify_args()
         if self.use_beam_search:
@@ -141,4 +144,5 @@ def __repr__(self) -> str:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"echo={self.echo})")