Commit 0de765e

Supports prompt_token_ids in the OpenAI completion API. (#1)
1 parent 79af7e9 commit 0de765e

2 files changed: +25 −12 lines


vllm/entrypoints/openai/api_server.py

Lines changed: 22 additions & 9 deletions

@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
 
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -375,17 +375,27 @@ async def create_completion(raw_request: Request):
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, str) or isinstance(first_element, list):
+            # TODO(@wanmok): handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +415,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
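With this change, the prompt field of a completion request may carry pre-tokenized input, which the server routes to engine.generate(..., prompt_token_ids=...). Below is a minimal client-side sketch, assuming a locally running vLLM OpenAI-compatible server on port 8000; the host, port, model name, and token IDs are illustrative assumptions, not part of this commit.

# Minimal sketch: send token IDs instead of text to the completions endpoint.
# Host, port, model name, and the token IDs here are illustrative assumptions.
import requests

payload = {
    "model": "facebook/opt-125m",     # assumed model name
    "prompt": [2, 100, 414, 16, 10],  # a single prompt given as token IDs
    "max_tokens": 16,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])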

vllm/entrypoints/openai/protocol.py

Lines changed: 3 additions & 3 deletions

@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -99,8 +100,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
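The widened prompt annotation is what lets the same request model accept plain text, a batch of strings, token IDs, or a batch of token-ID lists. A stand-alone sketch of that validation behavior, using a hypothetical model rather than the real CompletionRequest (other fields omitted for brevity):

# Stand-alone sketch of the widened prompt type; it mirrors the Union added to
# CompletionRequest but is a hypothetical model, not the actual vLLM class.
from typing import List, Union
from pydantic import BaseModel

class PromptOnlyRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]

PromptOnlyRequest(model="m", prompt="Hello, world")    # plain text
PromptOnlyRequest(model="m", prompt=["Hello, world"])  # single-item string batch
PromptOnlyRequest(model="m", prompt=[2, 100, 414])     # token IDs
PromptOnlyRequest(model="m", prompt=[[2, 100, 414]])   # batch of token-ID lists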
