Supporting log probabilities of prompt tokens in both engine and OpenAI API server (aka echo) #959
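
A minimal usage sketch of the engine-level flag this PR adds, mirroring the new test further below (the model name and field names are taken from the diff; treat this as illustrative rather than a finalized API):

```python
from vllm import LLM, SamplingParams

# Illustrative only: `get_prompt_logprobs` is the flag introduced by this PR;
# `logprobs=0` requests just the chosen token's logprob at each position.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(get_prompt_logprobs=True, logprobs=0,
                        max_tokens=5, temperature=0.0)
for out in llm.generate(["Baltimore is the greatest city in"], params):
    # With the flag set, `logprobs` also covers the prompt positions, so the
    # first len(prompt_token_ids) entries correspond to the prompt tokens.
    n_prompt = len(out.prompt_token_ids)
    print(out.outputs[0].logprobs[:n_prompt])
```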

Closed
wants to merge 40 commits into from

Changes from 33 commits

Commits
721c778
A runnable implementation of `echo`.
wanmok Sep 6, 2023
4abbe99
Reformatted.
wanmok Sep 6, 2023
3a9649f
Fixed pylint `R1729` and `R1721`
wanmok Sep 6, 2023
f075e8b
Fixed pylint and allowed `logprobs=0, echo=True`.
wanmok Sep 6, 2023
e419414
Fixed edge case of when to return `null` - consistent with OpenAI.
wanmok Sep 6, 2023
50b30f5
Fixed the issue of using `prompt_token_ids` while not being able to g…
wanmok Sep 6, 2023
69f9e52
Fixed the error of computing logprobs.
wanmok Sep 7, 2023
1070080
Fixed the pylint E1133
wanmok Sep 7, 2023
bd180cb
Fixed the issue of running with distributed workers.
wanmok Sep 7, 2023
46af13a
Set torch default dtype in a context manager (#971)
Yard1 Sep 7, 2023
8309110
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
fdc4e5a
Moved postprocessing to `_process_sequence_group_samples()`.
wanmok Sep 7, 2023
f422556
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
4ede60d
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 7, 2023
6c1a70c
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 8, 2023
0bbf78d
Fixed the issue with the first `top_logprobs` being None.
wanmok Sep 8, 2023
b60817d
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 8, 2023
39f8268
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 11, 2023
061ef6f
Fixed handling of the batched case.
wanmok Sep 13, 2023
07240ab
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 13, 2023
9b874f2
Updated to address review comments.
wanmok Sep 18, 2023
bf47666
Merge branch 'main' into add-echo
wanmok Sep 18, 2023
7cab379
Reformatted.
wanmok Sep 18, 2023
73f542c
Added test for `get_prompt_logprobs`
wanmok Sep 18, 2023
d6e8f68
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 19, 2023
8541ae0
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 21, 2023
5b63edf
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 21, 2023
1aab537
Merge branch 'main' into add-echo
wanmok Sep 22, 2023
4b7fd8e
Reformatted sampling_params.py
wanmok Sep 22, 2023
364c7aa
Merge branch 'main' into add-echo
wanmok Sep 23, 2023
c6b8d1f
Reformatted
wanmok Sep 23, 2023
67a5de0
Merge branch 'vllm-project:main' into add-echo
wanmok Sep 24, 2023
17675ec
Merge branch 'main' into add-echo
wanmok Sep 27, 2023
792f857
Merge branch 'main' into add-echo
wanmok Sep 29, 2023
ed52b8e
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 1, 2023
0d643fd
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 4, 2023
18f9efa
Update sampler.py
wanmok Oct 4, 2023
9967c2f
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 9, 2023
a0ebd61
Merged log forward.
wanmok Oct 9, 2023
4121f86
Merge branch 'vllm-project:main' into add-echo
wanmok Oct 12, 2023

29 changes: 18 additions & 11 deletions tests/conftest.py
@@ -1,8 +1,9 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import pytest
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import ModelOutput

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -53,23 +54,29 @@ def __init__(
def generate(
self,
prompts: List[str],
raw_output: bool = False,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
) -> Union[List[Tuple[List[int], str]], List[ModelOutput]]:
if raw_output:
kwargs["return_dict_in_generate"] = True
outputs: Union[List[Tuple[List[int], str]], List[ModelOutput]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output_ids = self.model.generate(
model_output = self.model.generate(
input_ids.cuda(),
use_cache=True,
**kwargs,
)
output_str = self.tokenizer.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str))
if raw_output:
outputs.append(model_output)
else:
output_str = self.tokenizer.batch_decode(
model_output,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_ids = model_output.cpu().tolist()
outputs.append((output_ids, output_str))
return outputs

def generate_greedy(
100 changes: 100 additions & 0 deletions tests/engine/test_get_prompt_logprobs.py
@@ -0,0 +1,100 @@
import pytest
import torch

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]

TEST_PROMPTS = [
"Hello world",
"Hello world. This is a test.",
"Hello world. This is a test. This is a test.",
"To be or not to be,",
"This is a question.",
"Baltimore is the greatest city in",
]
# Test IDs that have the same prefix.
SAME_PREFIX_TEST_IDS = [0, 1, 2]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
):
hf_model = hf_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype)

# Test whether engine results include prompts.
echo_params = SamplingParams(get_prompt_logprobs=True, max_tokens=5)
echo_results = vllm_model.generate(TEST_PROMPTS,
sampling_params=echo_params)
for orig, (_, echoed) in zip(TEST_PROMPTS, echo_results):
assert orig == echoed[0][:len(orig)]

# Test whether prompt logprobs are included in the results.
echo_logprob_params = SamplingParams(get_prompt_logprobs=True,
max_tokens=5,
logprobs=0,
temperature=0.0)
echo_logprob_results = vllm_model.model.generate(
TEST_PROMPTS, sampling_params=echo_logprob_params)

# With `logprobs=0`, only the chosen tokens' logprobs are requested.
for result in echo_logprob_results:
assert result.outputs[0].logprobs is not None

# This also ensures that same prompts have the same prefix logprobs.
same_prefix_logprobs = None
for i in SAME_PREFIX_TEST_IDS:
result = echo_logprob_results[i]
prefix_logprobs = result.outputs[0].logprobs[:len(result.
prompt_token_ids)]
if same_prefix_logprobs is None:
same_prefix_logprobs = prefix_logprobs
else:
assert all(x == y
for x, y in zip(same_prefix_logprobs, prefix_logprobs))

# To test whether prompt logprobs are consistent with HF
hf_outputs = hf_model.generate(
TEST_PROMPTS,
raw_output=True,
do_sample=False,
max_new_tokens=5,
output_hidden_states=True,
output_scores=True,
)
hf_logprobs_list = []
for output in hf_outputs:
logits = torch.matmul(
output.hidden_states[0][-1],
hf_model.model.get_output_embeddings().weight.t(),
)[0] # batch_size=1
if hf_model.model.get_output_embeddings().bias is not None:
logits += hf_model.model.get_output_embeddings().bias.unsqueeze(0)
hf_logprobs_list.append(logits.log_softmax(dim=-1))

for vllm_result, hf_result in zip(echo_logprob_results, hf_logprobs_list):
prompt_token_ids = torch.tensor(vllm_result.prompt_token_ids[1:],
dtype=torch.long)
vllm_logprobs = torch.tensor([
vllm_result.outputs[0].logprobs[i + 1][tid]
for i, tid in enumerate(vllm_result.prompt_token_ids[1:])
])
hf_logprobs = hf_result[:-1].cpu()[
torch.arange(prompt_token_ids.shape[0]), prompt_token_ids]
assert vllm_logprobs.shape[0] == hf_logprobs.shape[0]
# This is not super tight due to multiple floating-point conversions
assert torch.isclose(
vllm_logprobs,
hf_logprobs.float(),
atol=3e-2,
rtol=5e-3,
).all()

del hf_model
del vllm_model
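
For context on what the comparison above checks: the prompt logprobs are simply the standard next-token log-probabilities under the HF model. A standalone sketch of the same quantity computed from a plain forward pass (not part of the PR; assumes facebook/opt-125m as in the test):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

ids = tok("Baltimore is the greatest city in", return_tensors="pt").input_ids
with torch.no_grad():
    logprobs = model(ids).logits.log_softmax(dim=-1)  # [1, seq_len, vocab]

# Position i predicts token i + 1, so gather each prompt token's logprob
# from the preceding position; the first token has no preceding context,
# which is why the first prompt position carries no logprob.
targets = ids[:, 1:]
prompt_logprobs = logprobs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
print(prompt_logprobs)
```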
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -258,6 +258,11 @@ def add_request(
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)
if prompt is None:
# If only `prompt_token_ids` is provided, this ensures
# the prompt string is present.
prompt = self.tokenizer.decode(prompt_token_ids,
skip_special_tokens=True)

# Create the sequences.
block_size = self.cache_config.block_size
@@ -356,6 +361,9 @@ def _process_sequence_group_samples(
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
seq = seq_group.seqs_dict[sample.parent_seq_id]
if sample.prompt_top_logprobs is not None:
seq.prompt_top_logprobs = sample.prompt_top_logprobs
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []

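
A small sketch of the behavior the first hunk enables: a request submitted with only token ids still ends up with a decoded prompt string attached (argument names are assumed from the engine API of the time; the token ids are hypothetical):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

# Hypothetical setup; any decoder-only model supported by vLLM would do.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request(
    request_id="0",
    prompt=None,                       # no prompt string supplied...
    sampling_params=SamplingParams(get_prompt_logprobs=True, max_tokens=5),
    prompt_token_ids=[2, 31414, 232],  # ...only (hypothetical) token ids
)
# With the change above, the engine decodes `prompt_token_ids` so that the
# echoed prompt text can be returned alongside its logprobs.
```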
78 changes: 57 additions & 21 deletions vllm/entrypoints/openai/api_server.py
@@ -156,27 +156,42 @@ async def show_available_models():
return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
def create_logprobs(
token_ids: List[int],
top_logprobs: List[Optional[Dict[int, float]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)

logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
if num_output_top_logprobs:
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
# Filter out additional logprobs for the chosen token
# so that only the requested number of top logprobs is returned
if not (len(step_top_logprobs) > num_output_top_logprobs
and i == token_id)
} if step_top_logprobs else None)
return logprobs


@@ -352,7 +367,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
for the API specification. This API mimics the OpenAI Completion API.

NOTE: Currently we do not support the following features:
- echo (since the vLLM engine does not currently support
- get_prompt_logprobs (since the vLLM engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
@@ -364,11 +379,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if error_check_ret is not None:
return error_check_ret

if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_self = request.echo and request.max_tokens == 0

if request.suffix is not None:
# The language models we currently support do not support suffix.
@@ -423,9 +435,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
max_tokens=request.max_tokens if not echo_self else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
get_prompt_logprobs=request.echo,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -470,16 +483,21 @@ def create_stream_response_json(
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
echo_self_ends = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if echo_self and echo_self_ends[i]:
continue
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
token_ids=output.token_ids[previous_num_tokens[i]:],
top_logprobs=output.logprobs[previous_num_tokens[i]:],
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
@@ -490,6 +508,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
echo_self_ends[i] = echo_self and len(
previous_texts[i]) == len(res.prompt_token_ids)
if echo_self and echo_self_ends[i]:
output.finish_reason = "length"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
@@ -516,16 +538,30 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res

assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
token_ids = (output.token_ids
if not echo_self else output.token_ids[:-1])
top_logprobs = (output.logprobs
if not echo_self else output.logprobs[:-1])
logprobs = create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if echo_self:
output_text = tokenizer.decode(output.token_ids[:-1],
skip_special_tokens=True)
else:
output_text = output.text
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
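
On the client side, the echo path above could be exercised against the OpenAI-compatible server roughly as follows (a sketch only: the server address and route are the usual vLLM defaults, and the response fields follow the LogProbs schema in protocol.py below):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Baltimore is the greatest city in",
        "echo": True,      # return the prompt (and its logprobs) as well
        "logprobs": 0,     # only the chosen token's logprob per position
        "max_tokens": 0,   # per the diff, echo the prompt with no completion
        "temperature": 0.0,
    },
)
choice = resp.json()["choices"][0]
print(choice["text"])                        # the echoed prompt
print(choice["logprobs"]["token_logprobs"])  # per-token prompt logprobs
```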
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -102,8 +102,7 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
30 changes: 17 additions & 13 deletions vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional

import torch
from xformers.ops import AttentionBias
@@ -18,25 +18,28 @@ class InputMetadata:
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
get_prompt_logprobs: Whether to return log probabilities for the prompt tokens.
Defaults to False.
"""

def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
def __init__(self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
get_prompt_logprobs: Optional[List[bool]] = None) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.get_prompt_logprobs = (get_prompt_logprobs if get_prompt_logprobs
is not None else [False] * len(seq_groups))

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
@@ -63,5 +66,6 @@ def __repr__(self) -> str:
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), '
f'slot_mapping={self.slot_mapping}')
f'block_tables={self.block_tables}, '
f'slot_mapping={self.slot_mapping}, '
f'get_prompt_logprobs={self.get_prompt_logprobs})')