Commit 665cbce

Added echo function to OpenAI API server. (#1504)
1 parent 7c60044 commit 665cbce

2 files changed, +71 -24 lines

vllm/entrypoints/openai/api_server.py

Lines changed: 70 additions & 22 deletions
@@ -160,27 +160,38 @@ async def show_available_models():
     return ModelList(data=model_cards)


-def create_logprobs(token_ids: List[int],
-                    id_logprobs: List[Dict[int, float]],
-                    initial_text_offset: int = 0) -> LogProbs:
+def create_logprobs(
+    token_ids: List[int],
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+    num_output_top_logprobs: Optional[int] = None,
+    initial_text_offset: int = 0,
+) -> LogProbs:
     """Create OpenAI-style logprobs."""
     logprobs = LogProbs()
     last_token_len = 0
-    for token_id, id_logprob in zip(token_ids, id_logprobs):
+    if num_output_top_logprobs:
+        logprobs.top_logprobs = []
+    for i, token_id in enumerate(token_ids):
+        step_top_logprobs = top_logprobs[i]
+        if step_top_logprobs is not None:
+            token_logprob = step_top_logprobs[token_id]
+        else:
+            token_logprob = None
         token = tokenizer.convert_ids_to_tokens(token_id)
         logprobs.tokens.append(token)
-        logprobs.token_logprobs.append(id_logprob[token_id])
+        logprobs.token_logprobs.append(token_logprob)
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
             logprobs.text_offset.append(logprobs.text_offset[-1] +
                                         last_token_len)
         last_token_len = len(token)

-        logprobs.top_logprobs.append({
-            tokenizer.convert_ids_to_tokens(i): p
-            for i, p in id_logprob.items()
-        })
+        if num_output_top_logprobs:
+            logprobs.top_logprobs.append({
+                tokenizer.convert_ids_to_tokens(i): p
+                for i, p in step_top_logprobs.items()
+            } if step_top_logprobs else None)
     return logprobs


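For illustration only (not part of this commit): a standalone sketch of the data shape the reworked create_logprobs expects, i.e. one optional dict per token keyed by candidate token id, where a None entry (such as the first prompt token, whose logprob is undefined) simply yields a None token logprob. The toy vocabulary and numbers below are invented.

    from typing import Dict, List, Optional

    # Toy stand-in for tokenizer.convert_ids_to_tokens; values are made up.
    TOY_VOCAB = {101: "<s>", 2023: "this", 2003: "is", 1996: "the", 2001: "was"}

    def convert_ids_to_tokens(token_id: int) -> str:
        return TOY_VOCAB.get(token_id, f"<{token_id}>")

    # One entry per echoed/generated token, mirroring the top_logprobs argument.
    token_ids: List[int] = [101, 2023, 2003]
    top_logprobs: List[Optional[Dict[int, float]]] = [
        None,                         # no logprob available for this position
        {2023: -0.12, 1996: -2.50},   # candidate token id -> logprob
        {2003: -0.34, 2001: -1.80},
    ]

    tokens = [convert_ids_to_tokens(tid) for tid in token_ids]
    token_logprobs = [
        None if step is None else step[tid]
        for tid, step in zip(token_ids, top_logprobs)
    ]
    print(tokens)          # ['<s>', 'this', 'is']
    print(token_logprobs)  # [None, -0.12, -0.34]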
@@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret

-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
+    # OpenAI API supports echoing the prompt when max_tokens is 0.
+    echo_without_generation = request.echo and request.max_tokens == 0

     if request.suffix is not None:
         # The language models we currently support do not support suffix.
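A hypothetical client-side view of what this unlocks, assuming a vLLM OpenAI-compatible server listening on http://localhost:8000 and a placeholder model name (neither comes from the commit): with echo=True and a nonzero max_tokens the returned text is the prompt followed by the completion, while echo=True with max_tokens=0 echoes only the prompt, which the scoring example further down builds on.

    import requests

    r = requests.post(
        "http://localhost:8000/v1/completions",   # assumed server address
        json={
            "model": "my-model",                  # placeholder model name
            "prompt": "The capital of France is",
            "max_tokens": 8,
            "echo": True,
        },
    )
    # The choice text starts with the prompt itself.
    print(r.json()["choices"][0]["text"])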
@@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
-            max_tokens=request.max_tokens,
+            max_tokens=request.max_tokens
+            if not echo_without_generation else 1,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
         )
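Read in isolation, the two new arguments amount to the following rule of thumb. The helper below is only a sketch of that rule (it is not part of the commit), assuming vLLM is installed so SamplingParams can be imported.

    from typing import Optional

    from vllm import SamplingParams

    def echo_aware_params(echo: bool, max_tokens: int,
                          logprobs: Optional[int]) -> SamplingParams:
        # When echoing with max_tokens == 0, the engine is still asked for one
        # token (presumably because it cannot be asked for zero); that token is
        # discarded later when the response text is assembled from the prompt.
        # Prompt logprobs are only requested when the prompt will be echoed.
        echo_without_generation = echo and max_tokens == 0
        return SamplingParams(
            max_tokens=max_tokens if not echo_without_generation else 1,
            logprobs=logprobs,
            prompt_logprobs=logprobs if echo else None,
        )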
@@ -495,24 +503,42 @@ def create_stream_response_json(
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        has_echoed = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
+                token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                offsets = len(previous_texts[i])
+                if request.echo and not has_echoed[i]:
+                    if not echo_without_generation:
+                        delta_text = res.prompt + delta_text
+                        token_ids = res.prompt_token_ids + token_ids
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                    else:
+                        delta_text = res.prompt
+                        token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                    has_echoed[i] = True
                 if request.logprobs is not None:
                     logprobs = create_logprobs(
-                        output.token_ids[previous_num_tokens[i]:],
-                        output.logprobs[previous_num_tokens[i]:],
-                        len(previous_texts[i]))
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=offsets,
+                    )
                 else:
                     logprobs = None
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+                finish_reason = output.finish_reason
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                     logprobs=logprobs,
+                    finish_reason=finish_reason,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
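For the streaming path, a hypothetical consumer (placeholder address and model name; the [DONE] terminator is assumed to follow the usual OpenAI server-sent-events convention): because of has_echoed, the prompt text and its token ids are prepended to the first chunk only.

    import json

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",   # assumed server address
        json={
            "model": "my-model",                  # placeholder model name
            "prompt": "The capital of France is",
            "max_tokens": 8,
            "echo": True,
            "stream": True,
        },
        stream=True,
    )
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue
        payload = raw[len(b"data: "):]
        if payload == b"[DONE]":                  # assumed stream terminator
            break
        chunk = json.loads(payload)
        # With echo on, only the first chunk carries the prompt text.
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()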
@@ -551,14 +577,36 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         final_res = res
     assert final_res is not None
     choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
     for output in final_res.outputs:
         if request.logprobs is not None:
-            logprobs = create_logprobs(output.token_ids, output.logprobs)
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
         else:
             logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
         choice_data = CompletionResponseChoice(
             index=output.index,
-            text=output.text,
+            text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
         )
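The echo-only path (echo=True, max_tokens=0, logprobs set) is what makes prompt scoring possible; a hypothetical example, again with a placeholder server address and model name. Note that the first prompt token can carry a None logprob, so it is filtered out before aggregating.

    import math

    import requests

    r = requests.post(
        "http://localhost:8000/v1/completions",   # assumed server address
        json={
            "model": "my-model",                  # placeholder model name
            "prompt": "The capital of France is Paris.",
            "max_tokens": 0,
            "echo": True,
            "logprobs": 1,
        },
    )
    token_logprobs = r.json()["choices"][0]["logprobs"]["token_logprobs"]
    scored = [lp for lp in token_logprobs if lp is not None]
    print("total prompt logprob:", sum(scored))
    print("prompt perplexity:", math.exp(-sum(scored) / len(scored)))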

vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 2 deletions
@@ -106,8 +106,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None


 class CompletionResponseChoice(BaseModel):
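A standalone sketch of the practical effect of this field change (the model below mirrors the field layout but is not imported from vLLM, and the .json() call assumes pydantic v1-style serialization): when the client does not request logprobs, top_logprobs now serializes as null instead of an empty list.

    from typing import Dict, List, Optional

    from pydantic import BaseModel, Field

    class LogProbs(BaseModel):
        # Mirror of the fields above, for illustration only.
        text_offset: List[int] = Field(default_factory=list)
        token_logprobs: List[Optional[float]] = Field(default_factory=list)
        tokens: List[str] = Field(default_factory=list)
        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None

    print(LogProbs().json())
    # {"text_offset": [], "token_logprobs": [], "tokens": [], "top_logprobs": null}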
