@@ -3,18 +3,18 @@

 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version

 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -375,17 +375,27 @@ async def create_completion(raw_request: Request):

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, str) or isinstance(first_element, list):
+            # TODO(@wanmok): handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
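Not part of the diff: a minimal standalone sketch of how the new branch interprets each accepted prompt shape. The helper name resolve_prompt and the token ids below are placeholders for illustration only.

    # Illustrative mirror of the prompt-parsing branch above; spells out what
    # each accepted request.prompt shape resolves to.
    def resolve_prompt(p):
        """Return (use_token_ids, prompt) the same way create_completion now does."""
        if isinstance(p, list):
            if len(p) == 0:
                raise ValueError("please provide at least one prompt")
            first = p[0]
            if isinstance(first, int):
                return True, p            # flat list of token ids
            if len(p) > 1:                # list[str] / list[list[int]] with > 1 entry
                raise ValueError("multiple prompts in a batch is not currently supported")
            return not isinstance(first, str), p[0]
        return False, p                   # plain string

    assert resolve_prompt("Hello") == (False, "Hello")
    assert resolve_prompt(["Hello"]) == (False, "Hello")
    assert resolve_prompt([9906, 1917]) == (True, [9906, 1917])
    assert resolve_prompt([[9906, 1917]]) == (True, [9906, 1917])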
@@ -405,7 +415,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
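With this change, a caller can send pre-tokenized input to the completions endpoint, and the handler forwards it to the engine via prompt_token_ids instead of a text prompt. A minimal client sketch, assuming the server runs on the default localhost:8000; the model name and token ids are placeholders:

    import requests

    # Hypothetical request: URL, model name, and token ids are placeholders.
    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            "prompt": [2, 100, 18, 59],   # token ids instead of a text prompt
            "max_tokens": 16,
        },
    )
    print(response.json())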