Supporting log probabilities of prompt tokens in both engine and OpenAI API server (aka echo) #959

Closed
wants to merge 40 commits into vllm-project:main from wanmok:add-echo
Changes from 18 commits
Commits
40 commits
721c778
A runnable implementation of `echo`.
wanmok Sep 6, 2023
4abbe99
Reformatted.
wanmok Sep 6, 2023
3a9649f
Fixed pylint `R1729` and `R1721`
wanmok Sep 6, 2023
f075e8b
Fixed pylint and allowed `logprobs=0, echo=True`.
wanmok Sep 6, 2023
e419414
Fixed edge case of when to return `null` - consistent with OpenAI.
wanmok Sep 6, 2023
50b30f5
Fixed the issue of using `prompt_token_ids` while not being able to g…
wanmok Sep 6, 2023
69f9e52
Fixed the error of computing logprobs.
wanmok Sep 7, 2023
1070080
Fixed the pylint E1133
wanmok Sep 7, 2023
bd180cb
Fixed the issue of running with distributed workers.
wanmok Sep 7, 2023
46af13a
Set torch default dtype in a context manager (#971)
Yard1 Sep 7, 2023
8309110
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
fdc4e5a
Moved postprocessing to `_process_sequence_group_samples()`.
wanmok Sep 7, 2023
f422556
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
4ede60d
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
6c1a70c
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 8, 2023
0bbf78d
Fixed the issue with the first `top_logprobs` being None.
wanmok Sep 8, 2023
b60817d
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 8, 2023
39f8268
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 11, 2023
061ef6f
Fixed handling of the batched case.
wanmok Sep 13, 2023
07240ab
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 13, 2023
9b874f2
Updated to address review comments.
wanmok Sep 18, 2023
bf47666
Merge branch 'main' into add-echo
wanmok Sep 18, 2023
7cab379
Reformatted.
wanmok Sep 18, 2023
73f542c
Added test for `get_prompt_logprobs`
wanmok Sep 18, 2023
d6e8f68
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 19, 2023
8541ae0
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 21, 2023
5b63edf
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 21, 2023
1aab537
Merge branch 'main' into add-echo
wanmok Sep 22, 2023
4b7fd8e
Reformatted sampling_params.py
wanmok Sep 22, 2023
364c7aa
Merge branch 'main' into add-echo
wanmok Sep 23, 2023
c6b8d1f
Reformatted
wanmok Sep 23, 2023
67a5de0
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 24, 2023
17675ec
Merge branch 'main' into add-echo
wanmok Sep 27, 2023
792f857
Merge branch 'main' into add-echo
wanmok Sep 29, 2023
ed52b8e
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 1, 2023
0d643fd
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 4, 2023
18f9efa
Update sampler.py
wanmok Oct 4, 2023
9967c2f
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 9, 2023
a0ebd61
Merged log forward.
wanmok Oct 9, 2023
4121f86
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 12, 2023
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -255,6 +255,9 @@ def add_request(
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)
if prompt is None:
prompt = self.tokenizer.decode(prompt_token_ids,
skip_special_tokens=True)

# Create the sequences.
block_size = self.cache_config.block_size
@@ -355,6 +358,11 @@ def _process_sequence_group_samples(
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
seq_data = seq_group.seqs_dict[sample.parent_seq_id].data
if sample.prompt_logprobs is not None:
seq_data.prompt_logprobs = sample.prompt_logprobs
if sample.prompt_top_logprobs is not None:
seq_data.prompt_top_logprobs = sample.prompt_top_logprobs
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []

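The llm_engine.py change does two things: it reconstructs a prompt string when a caller passes only prompt_token_ids (so a prompt text is always available for echoing), and it copies the prompt logprobs produced by the sampler onto the sequence data. The sketch below illustrates only the ids-to-text round-trip, assuming a Hugging Face tokenizer; the model name is a placeholder for the example, not something this PR prescribes.

```python
# Illustrative sketch of the ids -> text round-trip added to add_request().
# The tokenizer/model name is an assumption for the example only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

prompt = "The capital of France is"
prompt_token_ids = tokenizer.encode(prompt)  # existing path: text -> ids
# New path: if only token ids were supplied, recover a prompt string for echoing.
recovered_prompt = tokenizer.decode(prompt_token_ids, skip_special_tokens=True)
print(recovered_prompt)
```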
76 changes: 56 additions & 20 deletions vllm/entrypoints/openai/api_server.py
@@ -154,27 +154,34 @@ async def show_available_models():
return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
def create_logprobs(
token_ids: List[int],
token_logprobs: List[float],
top_logprobs: Optional[List[Dict[int, float]]] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
if top_logprobs:
logprobs.top_logprobs = []
for i, (token_id,
token_logprob) in enumerate(zip(token_ids, token_logprobs)):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)

logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
if top_logprobs:
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in top_logprobs[i].items()
} if top_logprobs[i] else None)
return logprobs


@@ -368,11 +375,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if error_check_ret is not None:
return error_check_ret

if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_self = request.echo and request.max_tokens == 0

if request.suffix is not None:
# The language models we currently support do not support suffix.
@@ -426,9 +430,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
max_tokens=request.max_tokens if not echo_self else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
echo=request.echo,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -476,16 +481,24 @@ def create_stream_response_json(
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
echo_self_ends = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if echo_self and echo_self_ends[i]:
continue
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
token_ids=output.token_ids[previous_num_tokens[i]:],
token_logprobs=output.
logprobs[previous_num_tokens[i]:],
top_logprobs=output.
top_logprobs[previous_num_tokens[i]:]
if request.logprobs > 0 else None,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
@@ -496,6 +509,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
echo_self_ends[i] = echo_self and len(
previous_texts[i]) == len(res.prompt_token_ids)
if echo_self and echo_self_ends[i]:
output.finish_reason = "length"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
@@ -526,16 +543,35 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res

assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
token_ids = (output.token_ids
if not echo_self else output.token_ids[:-1])
token_logprobs = (output.logprobs
if not echo_self else output.logprobs[:-1])
if request.logprobs == 0:
top_logprobs = None
else:
top_logprobs = (output.top_logprobs
if not echo_self else output.top_logprobs[:-1])
logprobs = create_logprobs(
token_ids=token_ids,
token_logprobs=token_logprobs,
top_logprobs=top_logprobs,
)
else:
logprobs = None
if echo_self:
output_text = tokenizer.decode(output.token_ids[:-1],
skip_special_tokens=True)
else:
output_text = output.text
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
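With these api_server.py changes, echo is no longer rejected: when echo=True and max_tokens=0, the server asks the engine for a single dummy token, discards it, and returns the decoded prompt together with its per-token logprobs. A hedged usage sketch, assuming a vLLM OpenAI-compatible server at http://localhost:8000 serving facebook/opt-125m (both placeholders):

```python
# Usage sketch: request echoed prompt logprobs from the OpenAI-compatible server.
# Host, port, and model name are assumptions; adjust to your deployment.
import requests

payload = {
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "max_tokens": 0,   # echo-only mode in this PR: score the prompt, generate nothing
    "echo": True,
    "logprobs": 1,     # also return top-1 alternatives for each prompt position
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
choice = resp.json()["choices"][0]
print(choice["text"])                         # the echoed prompt
print(choice["logprobs"]["token_logprobs"])   # [None, ...]: first token has no context
```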
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -100,8 +100,7 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
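The protocol.py change relaxes the schema so top_logprobs can be omitted entirely (serialized as null) when the client requested token logprobs but no alternatives. A minimal sketch of how the relaxed model behaves, restating the field definitions from the diff:

```python
# Sketch of the relaxed LogProbs schema: top_logprobs may now be None entirely.
from typing import Dict, List, Optional
from pydantic import BaseModel, Field

class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None

lp = LogProbs(tokens=["The"], token_logprobs=[-1.2], text_offset=[0])
print(lp.top_logprobs)  # None when no alternatives were requested
```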
23 changes: 12 additions & 11 deletions vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional

import torch
from xformers.ops import AttentionBias
@@ -18,25 +18,26 @@ class InputMetadata:
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
echo: Whether to echo the prompt tokens. Defaults to False.
"""

def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
def __init__(self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
echo: Optional[List[bool]] = None) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.echo = echo if echo is not None else [False] * len(seq_groups)

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
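InputMetadata now carries one echo flag per sequence group, defaulting to all-False so existing call sites keep their behaviour. A tiny illustrative sketch of that defaulting logic (the helper name is invented for the example, not part of the PR):

```python
# Illustrative helper mirroring the default added in InputMetadata.__init__:
# absent echo flags mean "no prompt scoring" for every sequence group.
from typing import List, Optional

def resolve_echo_flags(num_seq_groups: int,
                       echo: Optional[List[bool]] = None) -> List[bool]:
    return echo if echo is not None else [False] * num_seq_groups

print(resolve_echo_flags(3))                       # [False, False, False]
print(resolve_echo_flags(3, [True, False, True]))  # per-group opt-in
```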
133 changes: 123 additions & 10 deletions vllm/model_executor/layers/sampler.py
@@ -11,6 +11,9 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceOutputs

LogProbsList = Optional[List[Optional[float]]]
TopLogProbsList = Optional[List[Optional[Dict[int, float]]]]

_SAMPLING_EPS = 1e-5


@@ -33,23 +36,45 @@ def __init__(self, vocab_size: int) -> None:
super().__init__()
self.vocab_size = vocab_size

def _logits_forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
return logits

def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Decide whether to perform `echo`
if (input_metadata.echo is not None and any(input_metadata.echo)
and any(x.logprobs is not None
for _, x in input_metadata.seq_groups)):
echo_results = self._process_echo(
embedding=embedding,
hidden_states=hidden_states,
input_metadata=input_metadata,
embedding_bias=embedding_bias,
)
else:
echo_results = None

# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)

# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
logits = self._logits_forward(embedding, hidden_states, embedding_bias)

# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
@@ -58,8 +83,13 @@ def forward(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
logits = _apply_penalties(
logits,
output_tokens,
presence_penalties,
frequency_penalties,
self.vocab_size,
)

# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
@@ -83,10 +113,93 @@ def forward(
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)
# Re-compute to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
samples = _sample(probs, logprobs, input_metadata)
if echo_results:
for i, ss in enumerate(samples):
if i in echo_results:
logprobs_list, top_logprobs_list = echo_results[i]
for s in ss:
s.prompt_logprobs = logprobs_list
s.prompt_top_logprobs = top_logprobs_list

return samples

def _process_echo(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, Tuple[LogProbsList, TopLogProbsList]]:
prompt_index_list: List[int] = []
seq_data = []
has_prompt_seq_group_ids = []
for i, (x, _) in enumerate(input_metadata.seq_groups):
if len(input_metadata.seq_data[x[0]].output_token_ids) == 0:
seq_data.append(input_metadata.seq_data[x[0]])
has_prompt_seq_group_ids.append(i)
prompt_id_list: List[int] = []
start_idx = 0
for prompt_len, seq_datum in zip(input_metadata.prompt_lens, seq_data):
prompt_index_list.extend(
list(range(start_idx, start_idx + prompt_len - 1)))
start_idx += prompt_len

# Pick the first seq_id
prompt_id_list.extend(seq_datum.prompt_token_ids[1:])

if len(seq_data) == 0:
return

prompt_indices = torch.tensor(
prompt_index_list,
dtype=torch.long,
device=embedding.device,
)
selected_hidden_states = hidden_states[prompt_indices]

prompt_ids = torch.tensor(prompt_id_list,
dtype=torch.long,
device=embedding.device)

# Compute the logits for the prompt tokens.
log_probs = self._logits_forward(embedding, selected_hidden_states,
embedding_bias).log_softmax(dim=-1)
# Log probs used in the prompt
# Shift by 1 to account for the start token
selected_log_probs = log_probs.gather(
dim=-1, index=prompt_ids.unsqueeze(dim=-1)).squeeze(-1).tolist()

start_idx = 0
logprobs_list = [None] * len(seq_data)
top_logprobs_list = [None] * len(seq_data)
echo_results = {}
for i, (prompt_len, seq_datum) in enumerate(
zip(input_metadata.prompt_lens, seq_data)):
num_log_probs = input_metadata.seq_groups[i][1].logprobs
assert num_log_probs is not None
if num_log_probs > 0:
top_log_porbs = torch.topk(log_probs[start_idx:start_idx +
prompt_len],
k=num_log_probs,
dim=-1)
# seq_datum.prompt_top_logprobs = [None] + [
top_logprobs_list[i] = [None] + [
dict(zip(x, y))
for x, y in zip(top_log_porbs.indices.tolist(),
top_log_porbs.values.tolist())
]
# seq_datum.prompt_logprobs = [
logprobs_list[i] = [
None
] + selected_log_probs[start_idx:start_idx + prompt_len]
start_idx += prompt_len
echo_results[i] = (logprobs_list[i], top_logprobs_list[i])
return echo_results


def _prune_hidden_states(
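The heart of the PR is _process_echo in sampler.py: it reuses the hidden states at the prompt positions, projects them through the output embedding, and scores each prompt token under the distribution predicted at the previous position, so the first prompt token gets None. The standalone sketch below reproduces that shift-by-one computation with toy tensors; shapes and values are illustrative, not vLLM's actual data layout.

```python
# Toy reproduction of the shift-by-one prompt scoring done in _process_echo.
import torch

torch.manual_seed(0)
vocab_size, prompt_len, hidden_size = 32, 5, 8

hidden_states = torch.randn(prompt_len, hidden_size)   # one state per prompt position
embedding = torch.randn(vocab_size, hidden_size)       # (tied) output embedding matrix
prompt_token_ids = torch.randint(vocab_size, (prompt_len,))

logits = hidden_states @ embedding.t()                 # [prompt_len, vocab_size]
log_probs = logits.log_softmax(dim=-1)

# Position t predicts token t+1: drop the last position and the first token.
shifted = log_probs[:-1]                               # [prompt_len - 1, vocab_size]
targets = prompt_token_ids[1:]                         # tokens being scored
token_logprobs = shifted.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

top2 = torch.topk(shifted, k=2, dim=-1)                # top-k alternatives per position
top_logprobs = [dict(zip(ids, vals)) for ids, vals
                in zip(top2.indices.tolist(), top2.values.tolist())]

# As in the PR, prepend None because the first prompt token has no left context.
print([None] + token_logprobs.tolist())
print([None] + top_logprobs)
```

In the PR these values are attached per sequence group to SequenceData as prompt_logprobs and prompt_top_logprobs, which the API layer then converts into OpenAI-style LogProbs.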