
Commit 420f52f

chore: --wip--
Signed-off-by: Aaron Pham <[email protected]>
1 parent 2bb535e commit 420f52f

File tree

6 files changed: +119 -146 lines changed

vllm/v1/core/guided_decoding/__init__.py

Lines changed: 50 additions & 22 deletions
@@ -1,29 +1,59 @@
 from __future__ import annotations

-import copy, enum
+import copy
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Optional

+import torch
 import xgrammar as xgr

-from vllm.config import ModelConfig
-from vllm.logger import init_logger
+from vllm.config import VllmConfig
 from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus

-from .grammar import Grammar
-
 if TYPE_CHECKING:
-    from typing_extensions import Self
-
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

-    from .grammar import XGrammar
+__all__ = ["Grammar", "GuidedDecodingManager"]
+

-logger = init_logger(__name__)
+class Grammar:
+    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding

-__all__ = ["Grammar", "GuidedDecodingManager"]
+    def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int,
+                 ctx: xgr.CompiledGrammar) -> None:
+        self.matcher = matcher
+        self.vocab_size = vocab_size
+        self.ctx = ctx
+
+    def accept_token(self, token: int) -> bool:
+        # NOTE: accept_token determines whether we accept this token
+        # and also updates the machine state
+        return self.matcher.accept_token(token)
+
+    def allocate_bitmask(self, batch_size: int,
+                         vocab_size: int) -> torch.Tensor:
+        return xgr.allocate_token_bitmask(batch_size, vocab_size)
+
+    # this should be run in parallel with model decoding
+    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(bitmask, idx)
+
+    @staticmethod
+    def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        xgr.apply_token_bitmask_inplace(logits, vocab_mask)
+
+    def reset(self):
+        self.matcher.reset()
+
+    def copy(self):
+        return Grammar(matcher=xgr.GrammarMatcher(self.ctx),
+                       vocab_size=self.vocab_size,
+                       ctx=self.ctx)
+
+    def __copy__(self):
+        return self.copy()


 @dataclass
@@ -74,20 +104,17 @@ def collect(self, request: Request):
             return True
         return False

-    def __init__(self, *, backend: str, tokenizer_group: BaseTokenizerGroup,
-                 model_config: ModelConfig):
-        self._backend = backend
-        self.model_config = model_config
+    def __init__(self, *, vllm_config: VllmConfig,
+                 tokenizer_group: BaseTokenizerGroup):
+        self.vllm_config = vllm_config
         self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
         self.executor = ThreadPoolExecutor()
         self._lock = threading.Lock()
-        cls._registry[backend] = cls

-    def initialize_cache(self, key: GuidedDecodingKey) -> Self:
+    def initialize_cache(self, key: GuidedDecodingKey, max_threads: int = 8):
         request_type, grammar_spec = key
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-            tokenizer, stop_token_ids=stop_token_ids, vocab_size=vocab_size)
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer)
         compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads)
         if request_type == "json":
             if type(grammar_spec) is not str:
@@ -98,6 +125,7 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Self:
             ctx = compiler.compile_grammar(grammar_spec)
         else:
             raise ValueError("grammar is not of valid supported types.")
-        return Grammar(matcher=xgr.GrammarMatcher(ctx),
-                       vocab_size=self.model_config.hf_text_config.vocab_size,
-                       ctx=ctx)
+        return Grammar(
+            matcher=xgr.GrammarMatcher(ctx),
+            vocab_size=self.vllm_config.model_config.hf_text_config.vocab_size,
+            ctx=ctx)
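As orientation for the new Grammar wrapper above (not part of the commit): a minimal sketch of how its methods compose during one decoding step, assuming grammar is the object returned by GuidedDecodingManager.initialize_cache and logits is a (1, vocab_size) float tensor. The greedy argmax stand-in for the sampler and the helper name constrained_decode_step are illustrative only.

import torch

def constrained_decode_step(grammar, logits: torch.Tensor) -> int:
    # One bitmask row per request in the batch (batch size 1 here).
    bitmask = grammar.allocate_bitmask(1, grammar.vocab_size)
    # Record which tokens the grammar currently allows into row 0.
    grammar.fill_bitmask(bitmask, 0)
    # Mask disallowed tokens in the logits in place before sampling.
    grammar.apply_bitmask(logits, bitmask)
    token = int(torch.argmax(logits, dim=-1).item())
    # Advance the matcher state; accept_token returns False if the token is rejected.
    assert grammar.accept_token(token)
    return token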

vllm/v1/core/guided_decoding/grammar.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

vllm/v1/core/scheduler.py

Lines changed: 19 additions & 58 deletions
@@ -14,8 +14,7 @@
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                  compute_encoder_budget)
-from vllm.v1.core.guided_decoding import GuidedDecodingManager
-from vllm.v1.core.guided_decoding.grammar import Grammar
+from vllm.v1.core.guided_decoding import Grammar
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.metrics.stats import SchedulerStats
@@ -40,13 +39,11 @@ def __init__(
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         lora_config: Optional[LoRAConfig],
-        decoding_config: DecodingConfig,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
         self.model_config = model_config
-        self.decoding_config = decoding_config
         # TODO: Support LoRA.
         assert lora_config is None, "V1 does not support LoRA yet."
         # Scheduling constraints.
@@ -103,21 +100,6 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)

-        # A request queue for grammar compilation
-        self.grammar: Deque[Request] = deque()
-        # initialize the tokenizer on the scheduler (this is used for constrained decoding)
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            parallel_config=parallel_config,
-            lora_config=lora_config)
-        tokenizer_group.ping()
-        # setup guided decoding, right now uses xgrammar
-        self.guided_decoding_manager = GuidedDecodingManager(
-            backend=decoding_config.guided_decoding_backend,
-            tokenizer_group=tokenizer_group,
-            model_config=model_config)
-
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -133,25 +115,6 @@ def schedule(self) -> "SchedulerOutput":
         scheduled_running_reqs: List[Request] = []
         preempted_reqs: List[Request] = []

-        # we need to check the grammar queue for any requests that have finished FSM compilation
-        newly_grammar_reqs: List[Request] = []
-        scheduled_grammar_reqs: Deque[Request] = deque()
-        while self.grammar:
-            request = self.grammar.popleft()
-            try:
-                # When request first added via add_request, then it will be a future call
-                # check timeout and add it directly to previous queue
-                request.grammar = request.grammar.result(timeout=0.05)
-                request.status = RequestStatus.WAITING
-                newly_grammar_reqs.append(request)
-            except futures._base.TimeoutError:
-                scheduled_grammar_reqs.append(request)
-        self.grammar = scheduled_grammar_reqs
-
-        # append all newly ready requests to waiting queue with higher priority
-        for req in newly_grammar_reqs:
-            self.waiting.appendleft(req)
-
         req_to_new_block_ids: Dict[str, List[int]] = {}
         num_scheduled_tokens: Dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
@@ -238,13 +201,6 @@ def schedule(self) -> "SchedulerOutput":
                 self.encoder_cache_manager.allocate(request, i)
             encoder_budget = new_encoder_budget

-            # Track if we need guided decoding
-            # Create individual bitmask for requests with grammar
-            if request.grammar is not None:
-                if request.request_id not in guided_decoding_bitmasks:
-                    bitmask = request.grammar.allocate_bitmask(1, vocab_size)
-                    guided_decoding_bitmasks[request.request_id] = bitmask
-
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
             while self.waiting:
@@ -258,7 +214,8 @@ def schedule(self) -> "SchedulerOutput":
                 request = self.waiting[0]

                 # allocate bitmask on request on first round
-                if request.grammar: request.allocate_grammar_bitmask(vocab_size=vocab_size)
+                if request.grammar:
+                    request.allocate_grammar_bitmask(vocab_size=vocab_size)

                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = \
@@ -356,8 +313,12 @@ def schedule(self) -> "SchedulerOutput":
         ]
         running_reqs_data = [
             self._make_running_request_data(
-                req, req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens, grammar=req.grammar, grammar_bitmask=req.grammar_bitmask) for req in scheduled_running_reqs
+                req,
+                req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens,
+                grammar=req.grammar,
+                grammar_bitmask=req.grammar_bitmask)
+            for req in scheduled_running_reqs
         ]
         preempted_req_ids = {req.request_id for req in preempted_reqs}

@@ -375,7 +336,6 @@ def schedule(self) -> "SchedulerOutput":
             # It contains the request IDs that are finished in between
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
-            guided_decoding_bitmasks=guided_decoding_bitmasks,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
         )

@@ -398,7 +358,7 @@ def _make_running_request_data(
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
             req_data.grammar = grammar
-            req_data.grammar_bitmask=grammar_bitmask
+            req_data.grammar_bitmask = grammar_bitmask
         else:
             req_data = RunningRequestData.from_request(request, new_block_ids,
                                                        num_computed_tokens)
@@ -480,6 +440,8 @@ def update_from_output(
         scheduler_output: "SchedulerOutput",
         model_runner_output: "ModelRunnerOutput",
     ) -> EngineCoreOutputs:
+        # concern: batchsize >>>1000
+        # compilation << update
         # NOTE(woosuk): This method doesn't consider speculative decoding.
         sampled_token_ids = model_runner_output.sampled_token_ids
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -560,11 +522,7 @@ def _check_stop(self, request: Request) -> bool:

     def add_request(self, request: Request) -> None:
         self.requests[request.request_id] = request
-
-        if self.guided_decoding_manager.collect(request):
-            self.grammar.append(request)
-        else:
-            self.waiting.append(request)
+        self.waiting.append(request)

     def finish_requests(
         self,
@@ -648,7 +606,8 @@ def from_request(
                    sampling_params=request.sampling_params,
                    block_ids=block_ids,
                    num_computed_tokens=num_computed_tokens,
-                   grammar=request.grammar, grammar_bitmask=request.grammar_bitmask)
+                   grammar=request.grammar,
+                   grammar_bitmask=request.grammar_bitmask)


 @dataclass
@@ -671,7 +630,8 @@ def from_request(
         return cls(req_id=request.request_id,
                    block_ids=block_ids,
                    num_computed_tokens=num_computed_tokens,
-                   grammar=request.grammar, grammar_bitmask=request.grammar_bitmask)
+                   grammar=request.grammar,
+                   grammar_bitmask=request.grammar_bitmask)


 @dataclass
@@ -694,7 +654,8 @@ def from_request(
         return cls(req_id=request.request_id,
                    new_block_ids=new_block_ids,
                    num_computed_tokens=num_computed_tokens,
-                   grammar=request.grammar, grammar_bitmask=request.grammar_bitmask)
+                   grammar=request.grammar,
+                   grammar_bitmask=request.grammar_bitmask)


 @dataclass
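Side note (not from the commit): the hunks above remove the grammar-compilation queue from Scheduler.schedule(); the same future-polling idiom reappears, commented out, in EngineCore below. Reduced to a self-contained toy, the pattern is: submit compilation to a ThreadPoolExecutor and poll each Future with a short timeout so scheduling never blocks on compilation. All names here (compile_grammar, pending, ready) are placeholders.

from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from typing import Deque, List

executor = ThreadPoolExecutor()

def compile_grammar(spec: str) -> str:
    # stands in for GuidedDecodingManager.initialize_cache(key)
    return f"compiled({spec})"

pending: Deque[Future] = deque(
    executor.submit(compile_grammar, spec) for spec in ("json-schema-a", "regex-b"))
ready: List[str] = []
still_pending: Deque[Future] = deque()
while pending:
    fut = pending.popleft()
    try:
        # Take the result only if compilation already finished (50 ms grace period).
        ready.append(fut.result(timeout=0.05))
    except TimeoutError:
        # Not done yet; keep it queued for the next scheduling step.
        still_pending.append(fut)
pending = still_pending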

vllm/v1/engine/core.py

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,7 @@
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
+from vllm.v1.core.guided_decoding import GuidedDecodingManager
 from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
@@ -67,6 +68,28 @@ def __init__(
         self.mm_input_mapper_server = MMInputMapperServer(
             vllm_config.model_config)

+        # initialize the tokenizer on the scheduler (this is used for constrained decoding)
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=vllm_config.model_config,
+            scheduler_config=vllm_config.scheduler_config,
+            parallel_config=vllm_config.parallel_config,
+            lora_config=vllm_config.lora_config)
+        tokenizer_group.ping()
+        # setup guided decoding, right now uses xgrammar
+        self.guided_decoding_manager = GuidedDecodingManager(
+            vllm_config=vllm_config, tokenizer_group=tokenizer_group)
+
+        # while self.grammar:
+        #     request = self.grammar.popleft()
+        #     try:
+        #         # When request first added via add_request, then it will be a future call
+        #         # check timeout and add it directly to previous queue
+        #         request.grammar = request.grammar.result(timeout=0.05)
+        #         request.status = RequestStatus.WAITING
+        #         newly_grammar_reqs.append(request)
+        #     except futures._base.TimeoutError:
+        #         scheduled_grammar_reqs.append(request)
+
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
@@ -127,6 +150,9 @@ def step(self) -> EngineCoreOutputs:

         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
+        # update FSM async here
+        # two broadcast (bitmask + calculate) <-- manager
+        # copy CPU -> CPU IPC (concat multiple bitmask?)
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, output)
         return engine_core_outputs
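The "update FSM async here" notes above only hint at the intended flow. One possible reading (an assumption, not the commit's implementation): after execute_model returns sampled tokens, advance each request's grammar matcher off the critical path, e.g. on the manager's thread pool, so bitmask filling for the next step sees up-to-date matcher state. advance_grammars and its arguments are hypothetical.

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor()

def advance_grammars(requests, sampled_token_ids):
    # Submit matcher updates so they overlap with the rest of the engine step.
    futures = []
    for req, token in zip(requests, sampled_token_ids):
        if getattr(req, "grammar", None) is not None:
            # accept_token advances the FSM so the next fill_bitmask is correct.
            futures.append(executor.submit(req.grammar.accept_token, token))
    # Join before the next schedule() needs fresh bitmasks.
    return [f.result() for f in futures]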
