|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import copy |
| 4 | +import threading |
| 5 | +from abc import ABC, abstractmethod |
| 6 | +from concurrent.futures import ThreadPoolExecutor |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, |
| 9 | + get_args) |
| 10 | + |
| 11 | +from transformers import PreTrainedTokenizer |
| 12 | + |
| 13 | +from vllm.config import DecodingConfig, ModelConfig |
| 14 | +from vllm.logger import init_logger |
| 15 | +from vllm.utils import LazyLoader |
| 16 | +from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus |
| 17 | + |
| 18 | +from .grammar import Grammar |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + import xgrammar as xgr |
| 22 | + from transformers import PreTrainedTokenizer |
| 23 | + from typing_extensions import LiteralString |
| 24 | + |
| 25 | + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup |
| 26 | +else: |
| 27 | + xgr = LazyLoader("xgr", globals(), "xgrammar") |
| 28 | + |
| 29 | +logger = init_logger(__name__) |
| 30 | + |
| 31 | +__all__ = ["Grammar", "GuidedDecodingManager"] |
| 32 | + |
| 33 | + |
| 34 | +@dataclass |
| 35 | +class GrammarCache: |
| 36 | + value: Grammar | None |
| 37 | + event: threading.Event |
| 38 | + |
| 39 | + |
| 40 | +T = TypeVar("T", bound=str) |
| 41 | + |
| 42 | + |
| 43 | +class GuidedDecodingManager(ABC, Generic[T]): |
| 44 | + |
| 45 | + @abstractmethod |
| 46 | + def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: |
| 47 | + ... |
| 48 | + |
| 49 | + def flush(self): |
| 50 | + with self._lock: |
| 51 | + self.cache.clear() |
| 52 | + |
| 53 | + def cache_grammar(self, request: Request): |
| 54 | + return self.executor.submit(self._add_grammar_to_cache, request) |
| 55 | + |
| 56 | + def get_grammar(self, request: Request): |
| 57 | + with self._lock: |
| 58 | + entry = self.cache.get(request.guided_decoding_key) |
| 59 | + if entry is None or not entry.event.is_set(): return None |
| 60 | + return copy.copy(entry.value) if entry.value else None |
| 61 | + |
| 62 | + def should_add(self, request: Request): |
| 63 | + if not request.use_guided_decoding: return False |
| 64 | + request.grammar = self.get_grammar(request) |
| 65 | + if not request.grammar: |
| 66 | + request.grammar = self.cache_grammar(request) |
| 67 | + request.status = RequestStatus.WAITING_FOR_FSM |
| 68 | + return True |
| 69 | + return False |
| 70 | + |
| 71 | + def _add_grammar_to_cache(self, request: Request): |
| 72 | + key = request.guided_decoding_key |
| 73 | + with self._lock: |
| 74 | + cache_hit = False |
| 75 | + if key in self.cache: |
| 76 | + cache_hit, entry = True, self.cache[key] |
| 77 | + else: |
| 78 | + entry = GrammarCache(None, threading.Event()) |
| 79 | + self.cache[key] = entry |
| 80 | + |
| 81 | + if cache_hit: |
| 82 | + entry.event.wait() |
| 83 | + else: |
| 84 | + entry.value = self.initialize_cache(key) |
| 85 | + entry.event.set() |
| 86 | + return copy.copy(entry.value) if entry.value else None |
| 87 | + |
| 88 | + @classmethod |
| 89 | + def from_backend(cls, /, backend: LiteralString = "xgrammar", *, |
| 90 | + tokenizer_group: BaseTokenizerGroup, |
| 91 | + model_config: ModelConfig) -> GuidedDecodingManager[T]: |
| 92 | + manager_cls = cls._registry.get(backend) |
| 93 | + if manager_cls is None: raise ValueError( f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}") |
| 94 | + return manager_cls(tokenizer_group=tokenizer_group, model_config=model_config) |
| 95 | + |
| 96 | + _registry: dict[str, type[GuidedDecodingManager[T]]] = {} |
| 97 | + _backend: T |
| 98 | + |
| 99 | + def __init__(self, *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig): |
| 100 | + self.model_config = model_config |
| 101 | + self.tokenizer = tokenizer_group.get_lora_tokenizer(None) |
| 102 | + self.cache: dict[GuidedDecodingKey, GrammarCache] = {} |
| 103 | + self.executor = ThreadPoolExecutor() |
| 104 | + self._lock = threading.Lock() |
| 105 | + |
| 106 | + def __init_subclass__(cls, **kwargs: Any): |
| 107 | + if not hasattr(cls, '__orig_bases__'): |
| 108 | + raise TypeError( |
| 109 | + f"{cls.__qualname__} must be subclass of GuidedDecodingManager" |
| 110 | + ) |
| 111 | + |
| 112 | + backend = None |
| 113 | + for base in cls.__orig_bases__: |
| 114 | + if (origin := get_args(base)) and issubclass( |
| 115 | + base.__origin__, GuidedDecodingManager): |
| 116 | + backend = get_args(origin[0])[0] |
| 117 | + break |
| 118 | + |
| 119 | + if backend is None: |
| 120 | + raise TypeError( |
| 121 | + f"Class {cls.__qualname__} must specify backend as a Literal type" |
| 122 | + ) |
| 123 | + |
| 124 | + if backend in cls._registry: |
| 125 | + name = cls._registry[backend].__qualname__ |
| 126 | + raise ValueError( |
| 127 | + f"Backend '{backend}' is already registered to {name}") |
| 128 | + |
| 129 | + # Set the backend value from the Literal type |
| 130 | + cls._backend = backend |
| 131 | + cls._registry[backend] = cls |
| 132 | + |
| 133 | + |
| 134 | +class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]): |
| 135 | + # cache GrammarCompiler instances based on given tokenizer |
| 136 | + _compiler_cache: dict[str, xgr.GrammarCompiler] = {} |
| 137 | + _compiler: xgr.GrammarCompiler | None = None |
| 138 | + |
| 139 | + def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: |
| 140 | + request_type, grammar_spec = key |
| 141 | + compiler = XGrammarManager.get_compiler(self.tokenizer) |
| 142 | + if request_type == "json": |
| 143 | + if type(grammar_spec) is not str: |
| 144 | + ctx = compiler.compile_builtin_json_grammar() |
| 145 | + else: |
| 146 | + ctx = compiler.compile_json_schema(grammar_spec) |
| 147 | + elif request_type == "grammar": |
| 148 | + ctx = compiler.compile_grammar(grammar_spec) |
| 149 | + else: |
| 150 | + raise ValueError("grammar is not of valid supported types.") |
| 151 | + return Grammar.from_backend( |
| 152 | + self._backend, |
| 153 | + matcher=xgr.GrammarMatcher(ctx), |
| 154 | + vocab_size=self.model_config.hf_text_config.vocab_size, |
| 155 | + ctx=ctx) |
| 156 | + |
| 157 | + def flush(self): |
| 158 | + super().flush() |
| 159 | + if self._compiler: self._compiler.clear_cache() |
| 160 | + for compiler in self._compiler_cache.values(): |
| 161 | + compiler.clear_cache() |
| 162 | + self._compiler_cache.clear() |
| 163 | + |
| 164 | + @classmethod |
| 165 | + def get_compiler( |
| 166 | + cls, |
| 167 | + tokenizer: PreTrainedTokenizer, |
| 168 | + *, |
| 169 | + max_threads: int = 8, |
| 170 | + # passthrough to TokenizerInfo |
| 171 | + vocab_size: int | None = None, |
| 172 | + stop_token_ids: list[int] | int | None = None |
| 173 | + ) -> xgr.GrammarCompiler: |
| 174 | + cache_key = str(hash(tokenizer)) |
| 175 | + if cache_key not in cls._compiler_cache: |
| 176 | + tokenizer_info = xgr.TokenizerInfo.from_huggingface( |
| 177 | + tokenizer, |
| 178 | + stop_token_ids=stop_token_ids, |
| 179 | + vocab_size=vocab_size) |
| 180 | + cls._compiler_cache[cache_key] = xgr.GrammarCompiler( |
| 181 | + tokenizer_info, max_threads=max_threads) |
| 182 | + return cls._compiler_cache[cache_key] |
0 commit comments