Commit d719c93

feat: initial guided decoding implementation on scheduler
Signed-off-by: Aaron Pham <[email protected]>
1 parent 6dd94db commit d719c93

8 files changed, +578 -39 lines changed
vllm/utils.py

Lines changed: 69 additions & 0 deletions
@@ -22,6 +22,7 @@
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 import weakref
@@ -2206,3 +2207,71 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
     else:
         func = partial(method, obj)  # type: ignore
     return func(*args, **kwargs)
+
+
+class LazyLoader(types.ModuleType):
+    """
+    LazyLoader module borrowed from TensorFlow
+    https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py
+    with an addition of "module caching". This will throw an
+    exception if the module cannot be imported.
+
+    Lazily import a module, mainly to avoid pulling in large dependencies.
+    `contrib` and `ffmpeg` are examples of modules that are large and not always
+    needed, and this allows them to only be loaded when they are used.
+    """
+
+    def __init__(
+        self,
+        local_name: str,
+        parent_module_globals: Dict[str, Any],
+        name: str,
+        warning: Optional[str] = None,
+        exc_msg: Optional[str] = None,
+        exc: Type[Exception] = Exception,
+    ):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._warning = warning
+        self._exc_msg = exc_msg
+        self._exc = exc
+        self._module: types.ModuleType | None = None
+
+        super().__init__(str(name))
+
+    def _load(self) -> types.ModuleType:
+        """Load the module and insert it into the parent's globals."""
+        from . import warn_deprecated
+
+        # Import the target module and insert it into the parent's namespace
+        try:
+            module = importlib.import_module(self.__name__)
+            self._parent_module_globals[self._local_name] = module
+            # The additional add to sys.modules ensures the library is actually loaded.
+            sys.modules[self._local_name] = module
+        except ModuleNotFoundError as err:
+            raise self._exc(f"{self._exc_msg} (reason: {err})") from None
+
+        # Emit a warning if one was specified
+        if self._warning:
+            warnings.warn(self._warning,
+                          category=DeprecationWarning,
+                          stacklevel=4)
+            # Make sure to only warn once.
+            self._warning = None
+
+        # Update this object's dict so that if someone keeps a reference to the
+        # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
+        # that fail).
+        self.__dict__.update(module.__dict__)
+        return module
+
+    def __getattr__(self, item: Any) -> Any:
+        if self._module is None:
+            self._module = self._load()
+        return getattr(self._module, item)
+
+    def __dir__(self) -> List[str]:
+        if self._module is None:
+            self._module = self._load()
+        return dir(self._module)
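
As a point of reference, here is a minimal usage sketch (not part of this diff) of how LazyLoader defers an import until the first attribute access; the new guided-decoding module below uses it the same way via xgr = LazyLoader("xgr", globals(), "xgrammar"). The numpy stand-in and the use_numpy helper are illustrative assumptions only.

from vllm.utils import LazyLoader

# Nothing is imported yet; "np" is only a module-typed placeholder.
np = LazyLoader("np", globals(), "numpy")

def use_numpy() -> int:
    # The first attribute lookup triggers _load(): numpy is imported, dropped
    # into globals() and sys.modules under the local name "np", and this
    # LazyLoader's __dict__ is updated so later lookups are direct.
    return int(np.arange(4).sum())

The trade-off is a one-time cost on first access instead of at import time; once _load() has copied the module's __dict__, subsequent attribute lookups bypass __getattr__ entirely.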
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
from __future__ import annotations

import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar,
                    get_args)

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus

from .grammar import Grammar

if TYPE_CHECKING:
    import xgrammar as xgr
    from transformers import PreTrainedTokenizer
    from typing_extensions import LiteralString

    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)

__all__ = ["Grammar", "GuidedDecodingManager"]


@dataclass
class GrammarCache:
    value: Grammar | None
    event: threading.Event


T = TypeVar("T", bound=str)


class GuidedDecodingManager(ABC, Generic[T]):

    @abstractmethod
    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
        ...

    def flush(self):
        with self._lock:
            self.cache.clear()

    def cache_grammar(self, request: Request):
        return self.executor.submit(self._add_grammar_to_cache, request)

    def get_grammar(self, request: Request):
        with self._lock:
            entry = self.cache.get(request.guided_decoding_key)
            if entry is None or not entry.event.is_set():
                return None
            return copy.copy(entry.value) if entry.value else None

    def should_add(self, request: Request):
        if not request.use_guided_decoding:
            return False
        request.grammar = self.get_grammar(request)
        if not request.grammar:
            request.grammar = self.cache_grammar(request)
            request.status = RequestStatus.WAITING_FOR_FSM
            return True
        return False

    def _add_grammar_to_cache(self, request: Request):
        key = request.guided_decoding_key
        with self._lock:
            cache_hit = False
            if key in self.cache:
                cache_hit, entry = True, self.cache[key]
            else:
                entry = GrammarCache(None, threading.Event())
                self.cache[key] = entry

        if cache_hit:
            entry.event.wait()
        else:
            entry.value = self.initialize_cache(key)
            entry.event.set()
        return copy.copy(entry.value) if entry.value else None

    @classmethod
    def from_backend(cls, /, backend: LiteralString = "xgrammar", *,
                     tokenizer_group: BaseTokenizerGroup,
                     model_config: ModelConfig) -> GuidedDecodingManager[T]:
        manager_cls = cls._registry.get(backend)
        if manager_cls is None:
            raise ValueError(
                f"Backend '{backend}' not found in registry. "
                f"Available backends: {list(cls._registry)}")
        return manager_cls(tokenizer_group=tokenizer_group,
                           model_config=model_config)

    _registry: dict[str, type[GuidedDecodingManager[T]]] = {}
    _backend: T

    def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
                 model_config: ModelConfig):
        self.model_config = model_config
        self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.cache: dict[GuidedDecodingKey, GrammarCache] = {}
        self.executor = ThreadPoolExecutor()
        self._lock = threading.Lock()

    def __init_subclass__(cls, **kwargs: Any):
        if not hasattr(cls, '__orig_bases__'):
            raise TypeError(
                f"{cls.__qualname__} must be a subclass of GuidedDecodingManager")

        backend = None
        for base in cls.__orig_bases__:
            if (origin := get_args(base)) and issubclass(
                    base.__origin__, GuidedDecodingManager):
                backend = get_args(origin[0])[0]
                break

        if backend is None:
            raise TypeError(
                f"Class {cls.__qualname__} must specify the backend as a Literal type")

        if backend in cls._registry:
            name = cls._registry[backend].__qualname__
            raise ValueError(
                f"Backend '{backend}' is already registered to {name}")

        # Set the backend value from the Literal type
        cls._backend = backend
        cls._registry[backend] = cls


class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
    # cache GrammarCompiler instances based on given tokenizer
    _compiler_cache: dict[str, xgr.GrammarCompiler] = {}
    _compiler: xgr.GrammarCompiler | None = None

    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
        request_type, grammar_spec = key
        compiler = XGrammarManager.get_compiler(self.tokenizer)
        if request_type == "json":
            if type(grammar_spec) is not str:
                ctx = compiler.compile_builtin_json_grammar()
            else:
                ctx = compiler.compile_json_schema(grammar_spec)
        elif request_type == "grammar":
            ctx = compiler.compile_grammar(grammar_spec)
        else:
            raise ValueError("grammar is not of a valid supported type.")
        return Grammar.from_backend(
            self._backend,
            matcher=xgr.GrammarMatcher(ctx),
            vocab_size=self.model_config.hf_text_config.vocab_size,
            ctx=ctx)

    def flush(self):
        super().flush()
        if self._compiler:
            self._compiler.clear_cache()
        for compiler in self._compiler_cache.values():
            compiler.clear_cache()
        self._compiler_cache.clear()

    @classmethod
    def get_compiler(
        cls,
        tokenizer: PreTrainedTokenizer,
        *,
        max_threads: int = 8,
        # passthrough to TokenizerInfo
        vocab_size: int | None = None,
        stop_token_ids: list[int] | int | None = None
    ) -> xgr.GrammarCompiler:
        cache_key = str(hash(tokenizer))
        if cache_key not in cls._compiler_cache:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                stop_token_ids=stop_token_ids,
                vocab_size=vocab_size)
            cls._compiler_cache[cache_key] = xgr.GrammarCompiler(
                tokenizer_info, max_threads=max_threads)
        return cls._compiler_cache[cache_key]
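
To make the registration mechanism concrete, here is a minimal sketch (not part of this commit) of how another backend would plug in: the Literal parameter becomes the key under which __init_subclass__ records the class in GuidedDecodingManager._registry, and from_backend then resolves it by name. The backend name "my-backend", the MyBackendManager class, and its stubbed initialize_cache are hypothetical.

class MyBackendManager(GuidedDecodingManager[Literal["my-backend"]]):
    """Hypothetical backend; __init_subclass__ registers it automatically."""

    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
        # A real backend would compile the request's spec (JSON schema,
        # grammar string, ...) into a Grammar here, as XGrammarManager does.
        request_type, grammar_spec = key
        raise NotImplementedError(
            f"compile {request_type!r} spec {grammar_spec!r} into a Grammar")


# Construction then goes through the registry; the engine supplies the
# tokenizer group and model config:
# manager = GuidedDecodingManager.from_backend(
#     "my-backend", tokenizer_group=tokenizer_group, model_config=model_config)

On the scheduler side, should_add() either attaches a cached Grammar copy to the request or submits compilation to the thread pool and parks the request in WAITING_FOR_FSM until the corresponding GrammarCache event is set.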
