
Commit 2133c70

benchislett authored and jimpang committed
[V1][Feature] Enable Speculative Decoding with Structured Outputs (vllm-project#14702)
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent a07a68a commit 2133c70

9 files changed: 209 additions, 59 deletions

benchmarks/backend_request_func.py

Lines changed: 1 addition & 0 deletions
@@ -260,6 +260,7 @@ async def async_request_openai_completions(
             if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
+            "repetition_penalty": 1.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,

benchmarks/benchmark_serving_structured_output.py

Lines changed: 6 additions & 3 deletions
@@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
             copy.deepcopy(schema) for _ in range(args.num_prompts)
         ]
         for i in range(len(json_schemas)):
+            if "properties" not in json_schemas[i]:
+                json_schemas[i]["properties"] = {}
             json_schemas[i]["properties"][
                 f"__optional_field_{uuid.uuid4()}"] = {
                     "type":
@@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
         json_schemas = [schema] * args.num_prompts

         def gen_prompt(index: int):
-            return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
+            return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501

         def get_schema(index: int):
             return json_schemas[index % len(json_schemas)]
@@ -231,7 +233,8 @@ def _filter_func(item):
         idx -= len_dataset
         schema = dataset["schema"][idx]
         prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
-                                               tokenize=False)
+                                               tokenize=False,
+                                               add_generation_prompt=True)
         input_len = len(tokenizer(prompt).input_ids)
         completion = dataset["completion"][idx]

@@ -849,7 +852,7 @@ def main(args: argparse.Namespace):
                             'json', 'json-unique', 'grammar', 'regex',
                             'choice', 'xgrammar_bench'
                         ])
-    parser.add_argument("--json_schema_path",
+    parser.add_argument("--json-schema-path",
                         type=str,
                         default=None,
                         help="Path to json schema.")

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 28 additions & 7 deletions
@@ -16,13 +16,31 @@
 from vllm.platforms import current_platform
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

+NGRAM_SPEC_CONFIG = {
+    "model": "[ngram]",
+    "num_speculative_tokens": 5,
+    "prompt_lookup_max": 5,
+    "prompt_lookup_min": 1,
+}
+
+EAGLE_SPEC_CONFIG = {
+    "method": "eagle",
+    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+    "num_speculative_tokens": 5,
+}
+
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
     #FIXME: This test is flaky on CI thus disabled
     #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
+     NGRAM_SPEC_CONFIG),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
+     EAGLE_SPEC_CONFIG)
 ]

 PARAMS_MODELS_TOKENIZER_MODE = [
@@ -45,8 +63,9 @@ class CarDescription(BaseModel):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
-                         PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
+    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
 def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
@@ -58,6 +77,7 @@ def test_structured_output(
     guided_decoding_backend: str,
     tokenizer_mode: str,
     model_name: str,
+    speculative_config: dict[str, Any],
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")

@@ -71,7 +91,8 @@ def test_structured_output(
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend,
               guided_decoding_disable_any_whitespace=True,
-              tokenizer_mode=tokenizer_mode)
+              tokenizer_mode=tokenizer_mode,
+              speculative_config=speculative_config)

     #
     # Test 1: Generate JSON output based on a provided schema
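
As a usage sketch of what the new parameters exercise: structured output and speculative decoding can now be combined on the V1 engine by passing a speculative_config alongside the guided-decoding settings. The snippet below mirrors the ngram configuration from the test; it assumes a local vLLM install with the model available and is illustrative rather than part of the test itself.

import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

os.environ["VLLM_USE_V1"] = "1"

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    speculative_config={
        "model": "[ngram]",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 1,
    },
)

sampling_params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json={"type": "object"}),
)
outputs = llm.generate("Generate a small JSON object:", sampling_params)
print(outputs[0].outputs[0].text)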

vllm/v1/core/sched/scheduler.py

Lines changed: 12 additions & 5 deletions
@@ -441,7 +441,7 @@ def schedule(self) -> SchedulerOutput:
         grammar_bitmask = self.structured_output_manager.grammar_bitmask(
             self.requests,
             structured_output_request_ids,
-            len(self.running),
+            scheduled_spec_decode_tokens,
         )
         # Construct the scheduler output.
         new_reqs_data = [
@@ -682,10 +682,6 @@ def update_from_output(
                     self.encoder_cache_manager.free_encoder_input(
                         request, input_id)

-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                request.spec_token_ids = spec_token_ids[req_index]
-
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
@@ -717,6 +713,17 @@ def update_from_output(
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                     req_id, new_token_ids)

+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                if request.use_structured_output:
+                    metadata = request.structured_output_request
+                    assert metadata is not None and metadata.grammar is not None
+                    # Needs to happen after new_token_ids are accepted.
+                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                        spec_token_ids[req_index])
+                else:
+                    request.spec_token_ids = spec_token_ids[req_index]
+
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids:

vllm/v1/structured_output/__init__.py

Lines changed: 46 additions & 13 deletions
@@ -27,6 +27,7 @@ class StructuredOutputManager:
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
+
         self._grammar_bitmask: Optional[torch.Tensor] = None

         # The default max_workers if not specified is the number of CPUs * 5,
@@ -80,28 +81,60 @@ def grammar_bitmask(
         self,
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
-        batch_len: int,
+        scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
             return None

         if self._grammar_bitmask is None:
             assert self.backend is not None
-            self._grammar_bitmask = self.backend.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs)
-
-        # Fill the bitmask using the index of each request equal to its
-        # position in the batch. Resize the bitmask down to the size of
-        # the batch.
-        bitmask_tensor = self._grammar_bitmask
-        for req_id, batch_index in structured_output_request_ids.items():
+            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+            if self.vllm_config.speculative_config is not None:
+                max_num_spec_tokens = self.vllm_config.\
+                    speculative_config.num_speculative_tokens
+            else:
+                max_num_spec_tokens = 0
+
+            # Allocate a bitmask for each token needing to be checked:
+            # one for each speculative position, and one more for the
+            # bonus token / non-speculative token.
+            self._grammar_bitmask = \
+                self.backend.allocate_token_bitmask(
+                    max_batch_size * (1 + max_num_spec_tokens))
+
+        # Generate a batched bitmask for all structured output requests.
+        # When speculative decoding is enabled, we need to include multiple
+        # masks for each request, one for each possible bonus token position.
+        # These are stored inline in the tensor and unpacked by the gpu runner.
+        cumulative_index = 0
+        ordered_seq = sorted(structured_output_request_ids.items(),
+                             key=lambda x: x[1])
+        # NOTE: This outer loop can likely be parallelized to improve
+        # performance of bitmask generation for large batches.
+        for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            state_advancements = 0
+            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
+            for i, token in enumerate(req_tokens):
+                if not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                                                 cumulative_index)
+                if token is not None:
+                    # In order to generate the correct bitmask for each
+                    # position in the speculative sequence, we advance
+                    # the FSM state for each speculative token and rollback
+                    # to restore the previous state when we are finished.
+                    assert request.grammar.accept_tokens(req_id, [token])
+                    state_advancements += 1
+                cumulative_index += 1
+            if state_advancements > 0:
+                request.grammar.rollback(state_advancements)
+
+        bitmask_tensor = self._grammar_bitmask
+        if cumulative_index < self._grammar_bitmask.shape[0]:
+            bitmask_tensor = self._grammar_bitmask[:cumulative_index]

         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
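
To make the new layout concrete: the bitmask tensor now reserves (1 + num_speculative_tokens) rows per structured-output request, and each row is filled from the grammar state reached after accepting the preceding draft tokens, with a rollback at the end so the request's real state is untouched. A standalone sketch of that loop using a toy grammar (hypothetical class with the same accept/rollback contract, not the vLLM types; assumes torch is installed):

import torch

class ToyGrammar:
    """Stand-in grammar: tracks how many tokens have been accepted."""

    def __init__(self) -> None:
        self.depth = 0

    def is_terminated(self) -> bool:
        return False

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        bitmask[idx] = self.depth  # stand-in for a real vocab mask row

    def accept_tokens(self, req_id: str, tokens: list[int]) -> bool:
        self.depth += len(tokens)
        return True

    def rollback(self, num_tokens: int) -> None:
        self.depth -= num_tokens

grammar = ToyGrammar()
spec_tokens = [11, 12]  # draft tokens scheduled for one request
bitmask = torch.zeros(1 + len(spec_tokens), dtype=torch.int64)

cumulative_index = 0
advanced = 0
# One mask per draft position, plus one for the bonus/non-speculative token.
for token in spec_tokens + [None]:
    if not grammar.is_terminated():
        grammar.fill_bitmask(bitmask, cumulative_index)
    if token is not None:
        assert grammar.accept_tokens("req-0", [token])
        advanced += 1
    cumulative_index += 1
if advanced > 0:
    grammar.rollback(advanced)  # restore the pre-speculation state

print(bitmask)  # rows 0..2 were filled from three successive grammar states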

vllm/v1/structured_output/backend_guidance.py

Lines changed: 21 additions & 0 deletions
@@ -144,6 +144,27 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:

         return r

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the parser in sequence.
+        Will not advance the parser.
+
+        Returns the prefix list of tokens that are accepted by the parser.
+        """
+        if len(tokens) == 0:
+            return []
+        if self.ll_matcher.is_stopped():
+            return []
+
+        num_tokens = self.ll_matcher.validate_tokens(tokens)
+
+        self.check_error()
+
+        return tokens[:num_tokens]
+
+    def rollback(self, num_tokens: int) -> None:
+        self.ll_matcher.rollback(num_tokens)
+        self.check_error()
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         # this will automatically return [EOS] mask if the matcher is stopped
         # or otherwise in an error state

vllm/v1/structured_output/backend_types.py

Lines changed: 24 additions & 0 deletions
@@ -35,6 +35,30 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
             bool: True if the tokens are accepted, False otherwise.
         """

+    @abstractmethod
+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """
+        Validates the provided tokens against the grammar.
+        Will not advance the FSM.
+
+        Args:
+            tokens (list[int]): A list of token IDs to validate.
+
+        Returns:
+            list[int]: A list of accepted token IDs. Will be a prefix
+            of the input tokens, and empty if none are accepted.
+        """
+
+    @abstractmethod
+    def rollback(self, num_tokens: int) -> None:
+        """
+        Rolls back the state of the grammar by a specified number of tokens.
+        Will also revert counters for the number of processed tokens.
+
+        Args:
+            num_tokens (int): The number of tokens to roll back.
+        """
+
     @abstractmethod
     def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
         """

vllm/v1/structured_output/backend_xgrammar.py

Lines changed: 30 additions & 2 deletions
@@ -40,6 +40,11 @@ def __init__(self, vllm_config: VllmConfig):
         self.disable_any_whitespace = \
             vllm_config.decoding_config.disable_any_whitespace

+        self.num_speculative_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            self.num_speculative_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
@@ -118,7 +123,10 @@ def compile_grammar(self, request_type: StructuredOutputOptions,
                 f"grammar is not of valid supported types. ({request_type!s})")

         return XgrammarGrammar(
-            matcher=xgr.GrammarMatcher(ctx),
+            matcher=xgr.GrammarMatcher(
+                ctx,
+                max_rollback_tokens=self.num_speculative_tokens,
+            ),
             vocab_size=self.vocab_size,
             ctx=ctx,
         )
@@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar):
     # supporting different backends, in the future.
     # For now, just xgrammar.
     #
-    # TODO: support max_rollback_tokens
     # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
     # for jump-forward decoding

@@ -163,6 +170,27 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
             self.num_processed_tokens += 1
         return True

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the FSM in sequence.
+        Will not advance the FSM.
+
+        Returns the prefix list of tokens that are accepted by the FSM.
+        """
+        accepted_tokens = []
+        for token in tokens:
+            if self.matcher.accept_token(token):
+                accepted_tokens.append(token)
+            else:
+                break
+        if len(accepted_tokens) > 0:
+            # Rollback the FSM to the initial state
+            self.matcher.rollback(len(accepted_tokens))
+        return accepted_tokens
+
+    def rollback(self, num_tokens: int) -> None:
+        self.matcher.rollback(num_tokens)
+        self.num_processed_tokens -= num_tokens
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(bitmask, idx)

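
For reference, validate_tokens above is just a sequence of accept_token calls followed by a rollback of however many succeeded, which is why the matcher is now constructed with max_rollback_tokens. A hedged end-to-end sketch driving xgrammar directly in the same way (assumes the xgrammar and transformers packages are installed; the tokenizer, grammar, and draft text are illustrative):

import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)
ctx = compiler.compile_builtin_json_grammar()

# max_rollback_tokens must cover the number of draft tokens we may undo.
matcher = xgr.GrammarMatcher(ctx, max_rollback_tokens=5)

draft = tokenizer.encode('{"a"', add_special_tokens=False)
accepted = []
for tok in draft:
    if matcher.accept_token(tok):
        accepted.append(tok)
    else:
        break
if accepted:
    matcher.rollback(len(accepted))  # leave the matcher where it started

print(f"validated prefix: {len(accepted)} of {len(draft)} draft tokens")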

0 commit comments
