@@ -160,27 +160,38 @@ async def show_available_models():
     return ModelList(data=model_cards)


-def create_logprobs(token_ids: List[int],
-                    id_logprobs: List[Dict[int, float]],
-                    initial_text_offset: int = 0) -> LogProbs:
+def create_logprobs(
+    token_ids: List[int],
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+    num_output_top_logprobs: Optional[int] = None,
+    initial_text_offset: int = 0,
+) -> LogProbs:
     """Create OpenAI-style logprobs."""
     logprobs = LogProbs()
     last_token_len = 0
-    for token_id, id_logprob in zip(token_ids, id_logprobs):
+    if num_output_top_logprobs:
+        logprobs.top_logprobs = []
+    for i, token_id in enumerate(token_ids):
+        step_top_logprobs = top_logprobs[i]
+        if step_top_logprobs is not None:
+            token_logprob = step_top_logprobs[token_id]
+        else:
+            token_logprob = None
         token = tokenizer.convert_ids_to_tokens(token_id)
         logprobs.tokens.append(token)
-        logprobs.token_logprobs.append(id_logprob[token_id])
+        logprobs.token_logprobs.append(token_logprob)
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
             logprobs.text_offset.append(logprobs.text_offset[-1] +
                                         last_token_len)
         last_token_len = len(token)

-        logprobs.top_logprobs.append({
-            tokenizer.convert_ids_to_tokens(i): p
-            for i, p in id_logprob.items()
-        })
+        if num_output_top_logprobs:
+            logprobs.top_logprobs.append({
+                tokenizer.convert_ids_to_tokens(i): p
+                for i, p in step_top_logprobs.items()
+            } if step_top_logprobs else None)
     return logprobs


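The reworked create_logprobs accepts per-position dictionaries that may be None (positions with no recorded distribution, such as the first prompt token) and fills top_logprobs only when the client asked for alternatives. A minimal sketch of the new contract, with a hypothetical stub tokenizer standing in for the server's real one and made-up token IDs and logprobs:

from typing import Dict, List, Optional


class StubTokenizer:
    # Stands in for the Hugging Face tokenizer the server holds globally.
    def convert_ids_to_tokens(self, token_id: int) -> str:
        return f"<tok{token_id}>"


tokenizer = StubTokenizer()

# One entry per token; None marks a position without logprob info,
# e.g. the first prompt token when echoing. Values are made up.
token_ids: List[int] = [11, 42]
top_logprobs: List[Optional[Dict[int, float]]] = [
    None,
    {42: -0.1, 7: -2.3},
]

# Mirrors the loop body added above: take the chosen token's logprob
# where a distribution exists, otherwise record None.
token_logprobs = [
    tl[tid] if tl is not None else None
    for tid, tl in zip(token_ids, top_logprobs)
]
print(token_logprobs)  # [None, -0.1]
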
@@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret

-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
+    # OpenAI API supports echoing the prompt when max_tokens is 0.
+    echo_without_generation = request.echo and request.max_tokens == 0

     if request.suffix is not None:
         # The language models we currently support do not support suffix.
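With the hard error removed, echo requests are accepted, and the max_tokens == 0 case (echo the prompt, generate nothing, which the OpenAI API allows) is tracked separately as echo_without_generation. A hedged usage sketch against a locally running server; the address and model name below are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed server address
    json={
        "model": "facebook/opt-125m",  # hypothetical model name
        "prompt": "Hello, world",
        "max_tokens": 0,   # the echo_without_generation path
        "echo": True,      # return the prompt itself as the text
        "logprobs": 1,     # also return per-token prompt logprobs
    },
)
print(resp.json()["choices"][0]["text"])  # "Hello, world"
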
@@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
-            max_tokens=request.max_tokens,
+            max_tokens=request.max_tokens
+            if not echo_without_generation else 1,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
        )
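Two details in this hunk: the engine cannot run a request with zero tokens, so echo-without-generation requests are submitted with max_tokens=1 and the generated token is later dropped from the echoed response; and prompt_logprobs is requested only when the client actually asked for echo. A minimal sketch of the resulting parameters for that case, assuming vLLM's SamplingParams exposes the prompt_logprobs field this change relies on:

from vllm import SamplingParams

params = SamplingParams(
    max_tokens=1,       # engine minimum; the token is discarded later
    logprobs=1,
    prompt_logprobs=1,  # per-token logprobs for the prompt itself
)
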
@@ -495,24 +503,42 @@ def create_stream_response_json(
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        has_echoed = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
+                token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                offsets = len(previous_texts[i])
+                if request.echo and not has_echoed[i]:
+                    if not echo_without_generation:
+                        delta_text = res.prompt + delta_text
+                        token_ids = res.prompt_token_ids + token_ids
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                    else:
+                        delta_text = res.prompt
+                        token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                    has_echoed[i] = True
                 if request.logprobs is not None:
                     logprobs = create_logprobs(
-                        output.token_ids[previous_num_tokens[i]:],
-                        output.logprobs[previous_num_tokens[i]:],
-                        len(previous_texts[i]))
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=offsets,
+                    )
                 else:
                     logprobs = None
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+                finish_reason = output.finish_reason
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                     logprobs=logprobs,
+                    finish_reason=finish_reason,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
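On the streaming path, the first chunk for each sequence now carries the echoed prompt (prepended to the first delta, or alone when nothing is generated), with has_echoed guarding against repeating it. A hedged sketch of consuming the stream; the address and model name are assumptions:

import json
import requests

with requests.post(
    "http://localhost:8000/v1/completions",  # assumed server address
    json={
        "model": "facebook/opt-125m",  # hypothetical model name
        "prompt": "Once upon a time",
        "max_tokens": 8,
        "echo": True,
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # Server-sent events: each payload line is "data: {...}".
        if line.startswith(b"data: ") and line != b"data: [DONE]":
            chunk = json.loads(line[len(b"data: "):])
            # The first chunk's text starts with the echoed prompt.
            print(chunk["choices"][0]["text"], end="")
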
@@ -551,14 +577,36 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         final_res = res
     assert final_res is not None
     choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
     for output in final_res.outputs:
         if request.logprobs is not None:
-            logprobs = create_logprobs(output.token_ids, output.logprobs)
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
         else:
             logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
         choice_data = CompletionResponseChoice(
             index=output.index,
-            text=output.text,
+            text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
         )
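In the non-streaming path, the choice text is assembled from the prompt and/or the generated text depending on echo and echo_without_generation. An illustrative choice object for a pure prompt-scoring request (echo=True, max_tokens=0, logprobs=1); the tokens and numbers below are hypothetical:

example_choice = {
    "index": 0,
    "text": "Once upon a time",  # the echoed prompt, nothing generated
    "logprobs": {
        "tokens": ["Once", " upon", " a", " time"],
        # The first prompt token has no logprob, hence None.
        "token_logprobs": [None, -1.2, -0.4, -0.1],
        "top_logprobs": [None, {" upon": -1.2}, {" a": -0.4}, {" time": -0.1}],
        "text_offset": [0, 4, 9, 11],
    },
    "finish_reason": "length",
}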