Added echo function to OpenAI API server. #1504

Merged (9 commits) on Nov 27, 2023
vllm/entrypoints/openai/api_server.py (92 changes: 70 additions & 22 deletions)
@@ -160,27 +160,38 @@ async def show_available_models():
     return ModelList(data=model_cards)


-def create_logprobs(token_ids: List[int],
-                    id_logprobs: List[Dict[int, float]],
-                    initial_text_offset: int = 0) -> LogProbs:
+def create_logprobs(
+    token_ids: List[int],
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+    num_output_top_logprobs: Optional[int] = None,
+    initial_text_offset: int = 0,
+) -> LogProbs:
     """Create OpenAI-style logprobs."""
     logprobs = LogProbs()
     last_token_len = 0
-    for token_id, id_logprob in zip(token_ids, id_logprobs):
+    if num_output_top_logprobs:
+        logprobs.top_logprobs = []
+    for i, token_id in enumerate(token_ids):
+        step_top_logprobs = top_logprobs[i]
+        if step_top_logprobs is not None:
+            token_logprob = step_top_logprobs[token_id]
+        else:
+            token_logprob = None
         token = tokenizer.convert_ids_to_tokens(token_id)
         logprobs.tokens.append(token)
-        logprobs.token_logprobs.append(id_logprob[token_id])
+        logprobs.token_logprobs.append(token_logprob)
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
             logprobs.text_offset.append(logprobs.text_offset[-1] +
                                         last_token_len)
         last_token_len = len(token)

-        logprobs.top_logprobs.append({
-            tokenizer.convert_ids_to_tokens(i): p
-            for i, p in id_logprob.items()
-        })
+        if num_output_top_logprobs:
+            logprobs.top_logprobs.append({
+                tokenizer.convert_ids_to_tokens(i): p
+                for i, p in step_top_logprobs.items()
+            } if step_top_logprobs else None)
     return logprobs
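
The None handling above is the crux of echo support: positions without recorded logprobs (typically the first prompt token) must produce None entries rather than a dictionary lookup. Below is a minimal, self-contained sketch of that bookkeeping; the stub tokenizer and the sample ids/logprob dicts are hypothetical stand-ins for the module-level HF tokenizer and the engine's real output.

from typing import Dict, List, Optional


class StubTokenizer:
    """Hypothetical stand-in for the HF tokenizer used in api_server.py."""

    def convert_ids_to_tokens(self, token_id: int) -> str:
        return f"<tok{token_id}>"


tokenizer = StubTokenizer()

# One dict of candidate logprobs per position; None marks a position for
# which no logprobs were recorded (e.g. the first prompt token when echoing).
token_ids: List[int] = [11, 42]
top_logprobs: List[Optional[Dict[int, float]]] = [None, {42: -0.12, 7: -2.8}]

tokens: List[str] = []
token_logprobs: List[Optional[float]] = []
for i, token_id in enumerate(token_ids):
    step = top_logprobs[i]
    tokens.append(tokenizer.convert_ids_to_tokens(token_id))
    token_logprobs.append(step[token_id] if step is not None else None)

print(tokens)          # ['<tok11>', '<tok42>']
print(token_logprobs)  # [None, -0.12]

The real function additionally tracks text offsets and converts each per-position dictionary into one keyed by token strings for the OpenAI-style response.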


@@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret

-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
+    # OpenAI API supports echoing the prompt when max_tokens is 0.
+    echo_without_generation = request.echo and request.max_tokens == 0

     if request.suffix is not None:
         # The language models we currently support do not support suffix.
@@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
-            max_tokens=request.max_tokens,
+            max_tokens=request.max_tokens
+            if not echo_without_generation else 1,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
         )
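
Pulling the two changes above together: echo is implemented by asking the engine for prompt logprobs, and the echo-only case (echo with max_tokens == 0) still requests one generated token, presumably because the engine requires at least one token of generation; that extra token is dropped again when the response is assembled further down. A hedged sketch of the mapping, with hypothetical request values:

# Sketch only: the request values below are made up for illustration.
from vllm import SamplingParams

echo = True        # request.echo
max_tokens = 0     # request.max_tokens
logprobs = 5       # request.logprobs

echo_without_generation = echo and max_tokens == 0

sampling_params = SamplingParams(
    # Echo-only requests still generate one token, discarded later.
    max_tokens=max_tokens if not echo_without_generation else 1,
    logprobs=logprobs,
    # Prompt logprobs are only needed when the prompt is echoed back.
    prompt_logprobs=logprobs if echo else None,
)
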
@@ -495,24 +503,42 @@ def create_stream_response_json(
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        has_echoed = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
+                token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                offsets = len(previous_texts[i])
+                if request.echo and not has_echoed[i]:
+                    if not echo_without_generation:
+                        delta_text = res.prompt + delta_text
+                        token_ids = res.prompt_token_ids + token_ids
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                    else:
+                        delta_text = res.prompt
+                        token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                    has_echoed[i] = True
                 if request.logprobs is not None:
                     logprobs = create_logprobs(
-                        output.token_ids[previous_num_tokens[i]:],
-                        output.logprobs[previous_num_tokens[i]:],
-                        len(previous_texts[i]))
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=offsets,
+                    )
                 else:
                     logprobs = None
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+                finish_reason = output.finish_reason
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                     logprobs=logprobs,
+                    finish_reason=finish_reason,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
@@ -551,14 +577,36 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         final_res = res
     assert final_res is not None
     choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
    for output in final_res.outputs:
         if request.logprobs is not None:
-            logprobs = create_logprobs(output.token_ids, output.logprobs)
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
         else:
             logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
         choice_data = CompletionResponseChoice(
             index=output.index,
-            text=output.text,
+            text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
         )
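
With the server-side changes in place, a client can exercise echo through the OpenAI-compatible completions endpoint. The example below is a hypothetical sketch: the host/port, model name, and prompt are assumptions for illustration, not part of this PR.

# Assumes a locally running server, e.g.:
#   python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
import requests

url = "http://localhost:8000/v1/completions"

# Echo and generate: the returned text starts with the prompt, and logprobs
# (if requested) cover the prompt tokens followed by the generated tokens.
resp = requests.post(url, json={
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "echo": True,
    "logprobs": 5,
})
print(resp.json()["choices"][0]["text"])

# Echo only: with max_tokens == 0 the server returns just the prompt (and its
# logprobs if requested), without any newly generated tokens.
resp = requests.post(url, json={
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "max_tokens": 0,
    "echo": True,
    "logprobs": 5,
})
print(resp.json()["choices"][0]["text"])
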
vllm/entrypoints/openai/protocol.py (3 changes: 1 addition & 2 deletions)
@@ -106,8 +106,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None


 class CompletionResponseChoice(BaseModel):
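
For orientation, this is roughly the shape of the logprobs object a client receives in a completion choice when both echo and logprobs are requested. Every token string and number below is invented; the leading None entries reflect that no logprob is recorded for the first prompt token.

# Hypothetical response fragment, shown as a Python dict for readability.
example_logprobs = {
    "text_offset": [0, 3, 11, 14],
    "tokens": ["The", " capital", " of", " France"],
    # None for the first prompt token: the engine reports no logprob for it.
    "token_logprobs": [None, -2.31, -0.05, -0.27],
    "top_logprobs": [
        None,
        {" capital": -2.31, " city": -2.95},
        {" of": -0.05},
        {" France": -0.27, " Spain": -2.10},
    ],
}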