[V1][Core] Support for Structured Outputs #12388
Merged: WoosukKwon merged 184 commits into vllm-project:main from aarnphm:v1/structured-decoding on Mar 7, 2025.
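For context, a minimal sketch (not taken from this PR's tests) of what a structured-output request looks like through the LLM entrypoint, using vLLM's existing GuidedDecodingParams request API. The model name is a placeholder, and enabling the v1 engine via VLLM_USE_V1=1 is an assumption about the setup at the time of this PR:

    import os

    os.environ["VLLM_USE_V1"] = "1"  # assumption: opt into the v1 engine

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

    # Constrain generation to a JSON object matching this schema.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }
    params = SamplingParams(
        max_tokens=64,
        guided_decoding=GuidedDecodingParams(json=schema),
    )

    outputs = llm.generate("Give me a JSON person record.", params)
    print(outputs[0].outputs[0].text)  # constrained token-by-token to the schema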
Commits (184)
d719c93  feat: initial guided decoding implementation on scheduler (aarnphm)
36bc041  chore: --wip-- (aarnphm)
39068c8  chore: remove lazy loader (aarnphm)
2bb535e  fix: update types and attach bitmask to requests (aarnphm)
420f52f  chore: --wip-- (aarnphm)
a5e9874  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
75e8fb4  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
9daf140  chore: --wip-- cleanup (aarnphm)
299ea58  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
15a4547  feat: base implementation (aarnphm)
49f7b96  fix: update the states within the scheduler (aarnphm)
cd357e5  [CI/Build] Ignore ruff warning up007 (russellb)
9a7b081  Resolve ruff errors (russellb)
2e43e04  chore: manage requests within manager class (aarnphm)
ccde524  Drop grammar getter/setter on Request (russellb)
1587d34  mypy: Fix return type of GPUModelRunner._prepare_inputs() (russellb)
227cc7f  Resolve remaining mypy warnings (russellb)
c0b235d  Finish getting pre-commit to pass (russellb)
49fdce0  Updat michael's suggestions (aarnphm)
e9a2304  chore: update according to Michael's review (aarnphm)
f6720a8  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
872c66f  chore: simplify cache implementations (aarnphm)
a8a2f27  Changes to get a test request working (russellb)
3fda148  Resolve mypy error in request (russellb)
d7a64eb  chore: remove debug print (aarnphm)
34c08ac  Enable some v1 structured output tests (russellb)
3b736ce  Validate structured output backend for v1 (russellb)
0bffe39  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
9f73ec9  Merge branch 'main' into v1/structured-decoding (aarnphm)
1a258fe  wip fixes for bitmask initialization and communication (russellb)
10f01f5  Clean up some remnants of inaccurate merge conflict resolution (russellb)
a6b07d1  fix: correctly use bitmask batch-wise (aarnphm)
7f255f0  fix: correct types (aarnphm)
9ab107f  chore: validate from decoding_config -> per request (aarnphm)
8d6bd3b  chore: passing vocab_size (aarnphm)
fcb0e85  chore: comment out 0.1.13 features (aarnphm)
3402b2a  Merge branch 'main' into v1/structured-decoding (aarnphm)
e6038f8  Resize bitmask to match the current batch size (russellb)
9830899  set any_whitespace=False for json schema + xgrammar (russellb)
cebe281  --wip--: debugging fsm apply (aarnphm)
862c093  fix: make sure to reset the FSM once we _free_request (aarnphm)
0df21ee  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
0fc85e3  revert: apply grammar bitmask from update states (aarnphm)
d95d1d7  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
62f8025  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
6a372ea  Revert changes to v0 guided decoding tests (russellb)
a43afca  create v1 tests_guided_generate for llm entrypoint (russellb)
fb40918  Drop unused Scheduler.guided_decoding_requests (russellb)
b8e016c  Allow grammar compilation to complete (russellb)
c63ca92  Remove some dead committed (russellb)
074b65d  Fix index calculation for guided requests in a batch (russellb)
727dab0  Make guided decoding manager more thread-safe (russellb)
adb50ff  chore: remove prefilled check (aarnphm)
5b818f9  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
c85408a  Re-enable line length checks in ruff (russellb)
b34e4a7  Fix a yapf error in main, will be fixed by #13772 (russellb)
0f2a97f  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
aabe98b  Prepare the bitmask on the scheduler side instead of gpu worker (russellb)
8895e19  tests: make sample jsonschema xgrammar compatible (russellb)
470b677  Detect unsupported jsonschema features for xgrammar (russellb)
42fe5f8  Make bitmask allocation synchronous (russellb)
ada4790  Fix compat with TP > 1 (russellb)
331a7ff  Make pre-commit happy again (russellb)
0984379  chore: remove reset_bitmask after every steps (aarnphm)
9b62eef  revert: update whitespace (aarnphm)
2f756e5  Add tests/v1/guided_decoding/test_utils.py (russellb)
72adc63  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
1be1709  add v1 structured output regex test case (russellb)
0128aff  Restore some code lost in a merge from main (russellb)
9cc90ff  Validate schema is supoprted before sending to threadpool (russellb)
3a8f955  chore: remove unused code (aarnphm)
e772efa  fix: correct typo (aarnphm)
64a2ecf  chore(scheduler): simplify check for use_guided_decoding (aarnphm)
e8f47f3  Move guided decode validation to the engine core_client (russellb)
f3f7d51  test for expected behavior of a choice guided decode request (russellb)
9582f8c  Validate jsonschema features for both str and dict cases (russellb)
acd5ae0  Test for expected behavior of a request with unsupported jsonschema f… (russellb)
4c674ae  Correctly differentiate between jsonschema and json object requests (russellb)
1b40882  Test for correct json object (no schema) request behavior (russellb)
4f551f4  Add test for a request using an EBNF style grammar (russellb)
d132d72  Validate that EBNF grammar can be parsed during early validation (russellb)
b994230  Test for expected behavior of an invalid grammar (russellb)
3cc6437  Add support and test coverage for lark style grammars (russellb)
95be24b  Add support and tests for choice based guided decoding (russellb)
9d1fe71  feat: spec decode compatibility [-------------] (aarnphm)
83a5277  fix: correct lock the matcher for both rollback and advance (aarnphm)
d02e11a  chore: only rollback if there are more than zero processed tokens (aarnphm)
c64daa7  fix: correctly free requests based on accepted tokens (aarnphm)
ad05fe8  Account for differences in scheduler and gpu worker batch ordering (russellb)
7cf6326  Skip non-guided-decode requests when assembling reordered bitmask (russellb)
84bbae1  revert: remove rollback check for now, only advance 1 token (aarnphm)
c10eb6a  Fix accidental re-use of cached grammar matcher (russellb)
0518b70  Use the correct indices for the logits bitmask (russellb)
5f23e8b  Update vllm/v1/core/scheduler_output.py (mgoin)
deb9b36  Apply suggestions from Russell (aarnphm)
4bcee6c  chore: update requests to remove unused function (aarnphm)
7aea044  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
3b49e8e  chore: address comments and renaming for clarity (aarnphm)
2a94f9c  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
5600a30  Move validation to process_inputs() (russellb)
2097d41  Reject spec decode + structured output requests (russellb)
d96c3ff  chore: update max_rollback_tokens to matcher (aarnphm)
9cf1a2c  chore: kw_only not available for 3.9 (aarnphm)
d1f7e8e  Revert a change that was a bad merge from main (russellb)
11d4483  Limit grammar compilation threadpool size (russellb)
9f47de1  benchmarks: filter unsupported jsonschemas from xgrammar_bench dataset (russellb)
bf5b844  chore: reduce code change for annotations type (aarnphm)
458a986  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
f12b8a4  benchmarks: Make first request honor guided-decoding-ratio (russellb)
e309361  perf: using two deque to always pop lists and preserved order (aarnphm)
05d930b  chore: address comments (aarnphm)
b18c39d  revert: minimal diff (aarnphm)
e8c3de2  revert: incorrect comments should be removed (aarnphm)
ec76e7e  Don't overwrite scheduler_output.grammar_bitmask (russellb)
f9e44b5  Use np.ndarray as the type between scheduler and workers (russellb)
df9c727  chore: add a shell script to run gallery of benchmark (aarnphm)
93a60d8  chore: add options to customize port (aarnphm)
bd27358  fix: make sure to save results as json (aarnphm)
cf43e97  chore: ignore nested JSON (aarnphm)
1377886  fix: allow set ratio (aarnphm)
a9033bc  chore: address woosuk comments (aarnphm)
75162be  fix: annotations casting (aarnphm)
b64630a  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
959b986  perf: use extendleft to put back skipped requests (aarnphm)
9b2e37a  Don't track guided request if it is not scheduled (russellb)
58c4fe6  Retain priority order among requests waiting for FSM compilation (russellb)
2df32a8  Encapsulate logic for filling in the grammar bitmask (russellb)
4d08950  Move more code from scheduler to guided_decoding_manager (russellb)
6d8b9aa  Move numpy imports outside of TYPE_CHECKING block (russellb)
2a3367f  Remove blank line as requested by Woosuk (russellb)
a74e737  Put guided decode type constants in all caps (russellb)
68d689c  Drop reference to rollback tokens, we don't support rollback (russellb)
e7d8fd6  Remove thin apply_bitmask() wrapper (russellb)
8f63db3  Avoid deprecated type annotations for built-in types (russellb)
51bdf22  Add docstring for accept_tokens() method (russellb)
fca387a  Drop an unnecessary todo comment (russellb)
5114a03  Remove unused rollback code (russellb)
6369660  Drop unused return value from populate_cache() (russellb)
4b7add8  Remove an extra dict lookup from the grammar cache (russellb)
11c252f  Make Request.use_guided_decoding a cached property (russellb)
6d99518  Remove an extra dict lookup in the GPU model runner (russellb)
4311643  chore: use fixture for test cases (aarnphm)
0ac6885  Factor out code for applying bitmask into its own method (russellb)
7482c97  Delay converting ndarray to tensor until necessary (russellb)
9de55e5  Remove unnecessary dataclass attributes from Grammar (russellb)
89c741f  chore: update notes for LRU cache (aarnphm)
949f644  chore: update naming to clearer name (aarnphm)
d4a59d5  Fix tests/v1/core/test_scheduler.py to pass again (russellb)
25534d4  Only increment num_processed_tokens if token is accepted by xgrammar (russellb)
56dabf9  Remove unnecessary continue pointed out in review (russellb)
5dd39a4  Remove unnecessary and expensive setup_grammars() loop (russellb)
b2e3c38  Remove unnecessary requests cache in guided decoding manager (russellb)
f564111  v1: standardize on structured output based naming (russellb)
4544cfb  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
e71e4cd  Fix benchmark script compat with V1 (russellb)
96424e0  refactor: move to separate request objects (aarnphm)
9e392fe  fix: gated lazy import for RequestStatus (aarnphm)
e2553e8  fix: only construct StructOutputRequest from given core (aarnphm)
470f930  chore: rename to all v1 struct outputs (aarnphm)
cb4b29a  chore: cleanup tests format (aarnphm)
0ed7dea  chore: update notes on this edge case for chunked prefill (aarnphm)
80fdb3c  Apply Nick's suggestion for correct types (aarnphm)
d4dd7aa  chore: renaming to more clearer with StructuredOutput (aarnphm)
5f8c3a7  chore: cleanup to use tuple (aarnphm)
d34c29f  chore: cleanup format (aarnphm)
b2ec4de  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
8243681  fix: push the queue up (aarnphm)
697e119  chore: remove cached_property (aarnphm)
e61518e  chore: update format (aarnphm)
170902a  chore: add CODEOWNERS (aarnphm)
098a864  revert: remove unused params (aarnphm)
0437876  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
56bba73  Rename files and variables to use full structured_output name (russellb)
0ae0785  Fix some missed files in my last rename commit (russellb)
ed77db1  Lazily import xgrammar because it initializes cuda as a side effect (russellb)
f769d39  chore: lazy import modules with type hint (aarnphm)
380d922  chore: lazy load np (aarnphm)
eeeca39  chore: finalize rename script (aarnphm)
6ca6b8b  chore: fix missing variables (aarnphm)
a259eca  chore: gated typing imports (aarnphm)
7556160  Merge remote-tracking branch 'origin/main' into v1/structured-decoding (russellb)
515172a  chore: update test state for gpu_model_runner (aarnphm)
915f2a3  chore: enable structured outputs tests (aarnphm)
dbb2024  merge: branch 'main' of github.com:vllm-project/vllm into v1/structur… (aarnphm)
File diff (new file):
@@ -0,0 +1,192 @@

from __future__ import annotations

import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args

from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus

from .grammar import Grammar

if TYPE_CHECKING:
    import xgrammar as xgr
    from transformers import PreTrainedTokenizer
    from typing_extensions import LiteralString

    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

    from .grammar import XGrammar
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)

__all__ = ["Grammar", "GuidedDecodingManager"]


@dataclass
class GrammarCache:
    value: Grammar | None
    event: threading.Event


T = TypeVar("T", bound=str)


class GuidedDecodingManager(ABC, Generic[T]):

    @abstractmethod
    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
        ...

    def flush(self):
        with self._lock:
            self.grammar_cache.clear()

    def cache(self, request: Request):

        def _executor_loop(request: Request):
            key = request.guided_decoding_key
            with self._lock:
                cache_hit = False
                if key in self.grammar_cache:
                    cache_hit, entry = True, self.grammar_cache[key]
                else:
                    entry = GrammarCache(None, threading.Event())
                    self.grammar_cache[key] = entry

            if cache_hit:
                entry.event.wait()
            else:
                entry.value = self.initialize_cache(key)
                entry.event.set()
            return copy.copy(entry.value) if entry.value else None

        return self.executor.submit(_executor_loop, request)

    def get(self, request: Request):
        with self._lock:
            entry = self.grammar_cache.get(request.guided_decoding_key)
            if entry is None or not entry.event.is_set():
                return None
            return copy.copy(entry.value) if entry.value else None

    def collect(self, request: Request):
        if not request.use_guided_decoding:
            return False
        request.grammar = self.get(request)
        if not request.grammar:
            request.grammar = self.cache(request)
            request.status = RequestStatus.WAITING_FOR_FSM
            return True
        return False

    @classmethod
    def from_backend(cls,
                     backend: LiteralString = "xgrammar",
                     /,
                     *,
                     tokenizer_group: BaseTokenizerGroup,
                     model_config: ModelConfig) -> GuidedDecodingManager[T]:
        manager_cls = cls._registry.get(backend)
        if manager_cls is None:
            raise ValueError(
                f"Backend '{backend}' not found in registry. "
                f"Available backends: {list(cls._registry)}")
        return manager_cls(tokenizer_group=tokenizer_group,
                           model_config=model_config)

    _registry: dict[str, type[GuidedDecodingManager[T]]] = {}
    _backend: T

    def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
                 model_config: ModelConfig):
        self.model_config = model_config
        self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
        self.executor = ThreadPoolExecutor()
        self._lock = threading.Lock()

    def __init_subclass__(cls, **kwargs: Any):
        if not hasattr(cls, '__orig_bases__'):
            raise TypeError(
                f"{cls.__qualname__} must be subclass of GuidedDecodingManager")

        backend = None
        for base in cls.__orig_bases__:
            if (origin := get_args(base)) and issubclass(
                    base.__origin__, GuidedDecodingManager):
                backend = get_args(origin[0])[0]
                break

        if backend is None:
            raise TypeError(
                f"Class {cls.__qualname__} must specify backend as a Literal type")

        if backend in cls._registry:
            name = cls._registry[backend].__qualname__
            raise ValueError(
                f"Backend '{backend}' is already registered to {name}")

        # Set the backend value from the Literal type
        cls._backend = backend
        cls._registry[backend] = cls


class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
    # cache GrammarCompiler instances based on given tokenizer
    _compiler_cache: dict[str, xgr.GrammarCompiler] = {}
    _compiler: xgr.GrammarCompiler | None = None

    def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
        request_type, grammar_spec = key
        compiler = XGrammarManager.get_compiler(self.tokenizer)
        if request_type == "json":
            if type(grammar_spec) is not str:
                ctx = compiler.compile_builtin_json_grammar()
            else:
                ctx = compiler.compile_json_schema(grammar_spec)
        elif request_type == "grammar":
            ctx = compiler.compile_grammar(grammar_spec)
        else:
            raise ValueError("grammar is not of valid supported types.")
        return Grammar.from_backend(
            self._backend,
            matcher=xgr.GrammarMatcher(ctx),
            vocab_size=self.model_config.hf_text_config.vocab_size,
            ctx=ctx)

    def flush(self):
        super().flush()
        if self._compiler:
            self._compiler.clear_cache()
        for compiler in self._compiler_cache.values():
            compiler.clear_cache()
        self._compiler_cache.clear()

    @classmethod
    def get_compiler(
            cls,
            tokenizer: PreTrainedTokenizer,
            *,
            max_threads: int = 8,
            # passthrough to TokenizerInfo
            vocab_size: int | None = None,
            stop_token_ids: list[int] | int | None = None
    ) -> xgr.GrammarCompiler:
        cache_key = str(hash(tokenizer))
        if cache_key not in cls._compiler_cache:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                stop_token_ids=stop_token_ids,
                vocab_size=vocab_size)
            cls._compiler_cache[cache_key] = xgr.GrammarCompiler(
                tokenizer_info, max_threads=max_threads)
        return cls._compiler_cache[cache_key]