
Commit 04421df

[V1] Prevent xgrammar from breaking TPU support (#14575)
Signed-off-by: Russell Bryant <[email protected]>
1 parent 432d6da · commit 04421df

2 files changed: +11 -2 lines changed


vllm/v1/engine/processor.py (4 additions, 0 deletions)

@@ -4,6 +4,7 @@
 from collections.abc import Mapping
 from typing import Optional, Union
 
+import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -133,6 +134,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if self.vllm_config.speculative_config:
             raise ValueError("Structured output is not supported with "
                              "speculative decoding.")
+        if vllm.platforms.current_platform.is_tpu():
+            raise ValueError("Structured output is not supported on TPU.")
+
         validate_structured_output_request(params)
 
     def process_inputs(
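
The guard sits in the Processor's request-validation path, so an unsupported
request fails fast with a clear ValueError before it reaches the engine core.
Note that the diff imports the vllm.platforms module rather than the
current_platform attribute, so the platform is looked up at call time instead
of being frozen when processor.py is imported. A minimal sketch of the same
pattern; the standalone function below is illustrative, not vLLM's API:

import vllm.platforms


def validate_structured_output_platform() -> None:
    # Illustrative free-function version of the check added to
    # Processor._validate_structured_output. current_platform resolves
    # to the platform vLLM detected (CUDA, ROCm, TPU, ...), so this is
    # a cheap per-request check.
    if vllm.platforms.current_platform.is_tpu():
        raise ValueError("Structured output is not supported on TPU.")

Because the check runs during input processing, TPU deployments that never
request structured output are unaffected.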

vllm/v1/structured_output/__init__.py (7 additions, 2 deletions)

@@ -17,6 +17,7 @@
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
+    import torch
     import xgrammar as xgr
 
     from vllm.v1.request import Request
@@ -53,8 +54,7 @@ def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500):
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask = xgr.allocate_token_bitmask(
-            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+        self._grammar_bitmask: Optional[torch.Tensor] = None
 
     def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
         # We need to pop and re-insert the grammar here for LRU cache
@@ -134,6 +134,11 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
+        if self._grammar_bitmask is None:
+            self._grammar_bitmask = xgr.allocate_token_bitmask(
+                self.vllm_config.scheduler_config.max_num_seqs,
+                self.vocab_size)
+
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
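
These hunks replace the eager bitmask allocation in __init__ with allocation
on first use in grammar_bitmask(), so a TPU deployment that never requests
structured output never calls into xgrammar's allocator (torch is imported
only under TYPE_CHECKING because it is needed just for the new annotation).
A self-contained sketch of the lazy-allocation pattern, assuming, per
xgrammar's documented behavior, that allocate_token_bitmask returns an int32
tensor of shape (batch_size, ceil(vocab_size / 32)); the class below is
illustrative, not vLLM code:

from typing import Optional

import torch


class LazyTokenBitmask:
    """Illustrative stand-in for the deferred allocation above."""

    def __init__(self, max_num_seqs: int, vocab_size: int):
        self.max_num_seqs = max_num_seqs
        self.vocab_size = vocab_size
        # Constructing the object allocates nothing, so construction is
        # safe even where the real allocator would fail.
        self._bitmask: Optional[torch.Tensor] = None

    def get(self) -> torch.Tensor:
        # Allocate on first use; later calls reuse the same buffer.
        if self._bitmask is None:
            # One int32 word covers 32 vocabulary entries, matching the
            # (batch, ceil(vocab / 32)) shape assumed above.
            words = (self.vocab_size + 31) // 32
            self._bitmask = torch.zeros(
                (self.max_num_seqs, words), dtype=torch.int32)
        return self._bitmask

Lazy initialization moves the potential failure from construction time to
first use; that is acceptable here because the TPU guard added in
processor.py ensures the first use never happens on TPU.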
