Commit 4fb8e32

[V1] [5/N] API Server: unify Detokenizer and EngineCore input (#11545)
Signed-off-by: [email protected] <[email protected]>
1 parent 328841d commit 4fb8e32
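
In short, Processor.process_inputs now returns a single EngineCoreRequest, and that one object is passed to both the Detokenizer and the EngineCore instead of two separate request types. A minimal sketch of the resulting call flow, assuming only the method names shown in the diffs below (the async path awaits add_request_async instead):

    # Sketch only: condensed from the updated add_request methods below.
    request = self.processor.process_inputs(request_id, prompt, params,
                                            arrival_time, lora_request,
                                            trace_headers,
                                            prompt_adapter_request, priority)

    # The same request object now feeds both consumers.
    self.detokenizer.add_request(request)   # in-process detokenization state
    self.engine_core.add_request(request)   # forwarded to the EngineCore process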

6 files changed: +66 -77 lines changed

tests/v1/engine/test_detokenizer.py

Lines changed: 35 additions & 22 deletions
@@ -3,9 +3,9 @@
 import pytest
 from transformers import AutoTokenizer

-from vllm.sampling_params import RequestOutputKind
-from vllm.v1.engine import EngineCoreOutput
-from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+from vllm.v1.engine.detokenizer import Detokenizer

 TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
@@ -71,16 +71,22 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):

     # Make N requests.
     requests = [
-        DetokenizerRequest(
-            request_id=f"request-{idx}",
-            prompt=prompt,
-            prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=request_output_kind,
-            stop=[],
-            include_stop_str_in_output=False,
-        ) for idx, (
+        EngineCoreRequest(request_id=f"request-{idx}",
+                          prompt=prompt,
+                          prompt_token_ids=prompt_tokens,
+                          arrival_time=0,
+                          mm_inputs=None,
+                          mm_hashes=None,
+                          mm_placeholders=None,
+                          eos_token_id=None,
+                          lora_request=None,
+                          sampling_params=SamplingParams(
+                              skip_special_tokens=False,
+                              spaces_between_special_tokens=False,
+                              output_kind=request_output_kind,
+                              stop=[],
+                              include_stop_str_in_output=False))
+        for idx, (
             prompt,
             prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]
@@ -133,18 +139,25 @@ def test_stop_string(include_stop_str_in_output: bool):

     # Make N requests.
     requests = [
-        DetokenizerRequest(
+        EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt=prompt,
             prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=RequestOutputKind.DELTA,
-            stop=STOP_STRINGS,
-            include_stop_str_in_output=include_stop_str_in_output,
-        ) for idx, (
-            prompt,
-            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
+            arrival_time=0,
+            mm_inputs=None,
+            mm_hashes=None,
+            mm_placeholders=None,
+            eos_token_id=None,
+            lora_request=None,
+            sampling_params=SamplingParams(
+                skip_special_tokens=False,
+                spaces_between_special_tokens=False,
+                output_kind=RequestOutputKind.DELTA,
+                stop=STOP_STRINGS,
+                include_stop_str_in_output=include_stop_str_in_output,
+            )) for idx, (
+                prompt,
+                prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]

     # Add requests to the detokenizer.
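
Condensed, each test request is now built roughly as follows (a sketch assembled from the diff above; the detokenization options move under SamplingParams, while engine-only fields such as the multimodal inputs and lora_request are simply None in these tests):

    request = EngineCoreRequest(
        request_id="request-0",
        prompt=prompt,
        prompt_token_ids=prompt_tokens,
        arrival_time=0,
        mm_inputs=None,
        mm_hashes=None,
        mm_placeholders=None,
        eos_token_id=None,
        lora_request=None,
        sampling_params=SamplingParams(
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=RequestOutputKind.DELTA,
            stop=STOP_STRINGS,
            include_stop_str_in_output=include_stop_str_in_output))
    detokenizer.add_request(request)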

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 15 deletions
@@ -6,21 +6,7 @@

 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-
-
-@dataclass
-class DetokenizerRequest:
-
-    request_id: str
-    prompt: Optional[str]
-    prompt_token_ids: List[int]
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
-    output_kind: RequestOutputKind
-
-    stop: List[str]
-    include_stop_str_in_output: bool
+from vllm.sampling_params import SamplingParams


 @dataclass

vllm/v1/engine/async_llm.py

Lines changed: 8 additions & 6 deletions
@@ -158,16 +158,18 @@ async def add_request(
             raise ValueError(f"Request id {request_id} already running.")
         self.rid_to_queue[request_id] = asyncio.Queue()

-        # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        # 2) Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         # 3) Add the request to Detokenizer (this process).
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)

         # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(engine_core_req)
+        await self.engine_core.add_request_async(request)

         if self.log_requests:
             logger.info("Added request %s.", request_id)

vllm/v1/engine/detokenizer.py

Lines changed: 11 additions & 10 deletions
@@ -8,7 +8,7 @@
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest

 logger = init_logger(__name__)

@@ -55,19 +55,19 @@ def output_token_ids(self) -> List[int]:
     def from_new_request(
         cls,
         tokenizer: AnyTokenizer,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":

         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
-            skip_special_tokens=request.skip_special_tokens,
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
         )

-        stops = request.stop
+        stops = request.sampling_params.stop
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stops and not request.include_stop_str_in_output:
+        if stops and not request.sampling_params.include_stop_str_in_output:
             stop_buffer_length = max(len(s) for s in stops) - 1
         else:
             stop_buffer_length = 0
@@ -79,13 +79,14 @@ def from_new_request(
             # NOTE(Nick): could we take ownership of it though?
             token_ids=request.prompt_token_ids.copy(),
             stop=stops,
-            include_stop_str_in_output=request.include_stop_str_in_output,
+            include_stop_str_in_output=request.sampling_params.
+            include_stop_str_in_output,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
-            skip_special_tokens=request.skip_special_tokens,
-            spaces_between_special_tokens=request.
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=request.sampling_params.
             spaces_between_special_tokens,
-            output_kind=request.output_kind,
+            output_kind=request.sampling_params.output_kind,
             request_id=request.request_id,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -227,7 +228,7 @@ def abort_requests(

     def add_request(
         self,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ):
         """Add new request to the Detokenizer."""
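
To make the hold-back logic above concrete: the detokenizer withholds up to max(len(s) for s in stops) - 1 trailing characters so that a partially streamed stop string can still be retracted before it reaches the client. A tiny worked example with illustrative stop strings (not taken from the tests):

    stops = ["###", "Observation:"]                        # illustrative values
    stop_buffer_length = max(len(s) for s in stops) - 1    # 12 - 1 = 11 chars held back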

vllm/v1/engine/llm_engine.py

Lines changed: 7 additions & 5 deletions
@@ -152,15 +152,17 @@ def add_request(
     ) -> None:

         # 1) Process raw inputs into the request.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         # 2) Add the request to Detokenizer.
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)

         # 3) Add the request to EngineCore.
-        self.engine_core.add_request(engine_core_req)
+        self.engine_core.add_request(request)

     def step(self) -> List[RequestOutput]:

vllm/v1/engine/processor.py

Lines changed: 4 additions & 19 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Union

 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -13,7 +13,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient


@@ -62,7 +62,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
+    ) -> EngineCoreRequest:

         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Check max_logprobs
@@ -123,20 +123,7 @@ def process_inputs(
             decoder_inputs.multi_modal_data, mm_hashes,
             decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

-        # Make Request for Detokenizer.
-        detokenizer_request = DetokenizerRequest(
-            request_id,
-            decoder_inputs.prompt,
-            decoder_inputs.prompt_token_ids,
-            sampling_params.skip_special_tokens,
-            sampling_params.spaces_between_special_tokens,
-            sampling_params.output_kind,
-            sampling_params.stop,
-            sampling_params.include_stop_str_in_output,
-        )
-
-        # Make Request for EngineCore.
-        engine_core_request = EngineCoreRequest(
+        return EngineCoreRequest(
             request_id,
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
@@ -149,8 +136,6 @@ def process_inputs(
             lora_request,
         )

-        return detokenizer_request, engine_core_request
-
     def _validate_model_inputs(self, inputs: ProcessorInputs):
         if is_encoder_decoder_inputs(inputs):
             # For encoder-decoder multimodal models, the max_prompt_len
