
Commit 74fa1d1

[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by: mgoin <[email protected]>
Parent: a2a40bc

4 files changed: +17 lines, -13 lines


tests/entrypoints/openai/test_completion.py

Lines changed: 6 additions & 8 deletions

@@ -28,6 +28,8 @@
 # need to change to match the prompt adapter
 PA_NUM_VIRTUAL_TOKENS = 8
 
+GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_json_schema):
@@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_regex):
@@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                         guided_decoding_backend: str,
                                         sample_guided_choice):
@@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str,
                                           sample_json_schema, sample_regex):
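
The parametrized tests above exercise each guided-decoding backend through vLLM's OpenAI-compatible completions endpoint. The snippet below is a minimal sketch of that pattern, not code from this commit: the server URL, model name, prompt, and schema are assumptions, while guided_json and guided_decoding_backend are the extra-body fields the tests parametrize over. The n=2 request is the parallel-sampling case this commit fixes.

import asyncio
import json

import openai  # official OpenAI client, pointed at a running vLLM server


async def check_guided_json(backend: str) -> None:
    # Assumes a vLLM OpenAI-compatible server listening on localhost:8000.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # assumed model name
        prompt="Give me a JSON object describing a person:",
        max_tokens=100,
        n=2,  # parallel sampling: every sample must respect the grammar
        extra_body=dict(guided_json=schema,
                        guided_decoding_backend=backend),
    )
    for choice in completion.choices:
        json.loads(choice.text)  # each sample should parse as valid JSON


asyncio.run(check_guided_json("xgrammar"))
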

vllm/model_executor/guided_decoding/xgrammar_decoding.py

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,7 @@
 # noqa: UP007
 from __future__ import annotations
 
+import copy
 import json
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
@@ -309,3 +310,7 @@ def __call__(self, input_ids: list[int],
         scores = scores.to(device_type).squeeze()
 
         return scores
+
+    def clone(self) -> XGrammarLogitsProcessor:
+        """Deepcopy due to per-sequence state in the matchers"""
+        return copy.deepcopy(self)
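
The new clone() exists because a grammar-constrained logits processor is stateful per sequence: its matcher advances as tokens are accepted, so two sequences sharing one instance would corrupt each other's position in the grammar. The toy class below illustrates that contract under assumed names; it is not xgrammar's real implementation, only a processor with mutable per-sequence state and the same clone() hook.

import copy
from typing import List

import torch


class StatefulLogitsProcessor:
    """Toy stand-in for a grammar-constrained processor (hypothetical)."""

    def __init__(self, allowed_token_ids: List[int]):
        self.allowed_token_ids = allowed_token_ids
        self.num_tokens_seen = 0  # mutable per-sequence state

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        # State advances with the sequence, so one instance must not be
        # shared between parallel samples of the same request.
        self.num_tokens_seen = len(input_ids)
        mask = torch.full_like(scores, float("-inf"))
        mask[self.allowed_token_ids] = 0
        return scores + mask

    def clone(self) -> "StatefulLogitsProcessor":
        # Same contract as XGrammarLogitsProcessor.clone(): give each
        # sequence its own copy of the mutable state.
        return copy.deepcopy(self)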

vllm/sampling_params.py

Lines changed: 5 additions & 4 deletions

@@ -450,15 +450,16 @@ def all_stop_token_ids(self) -> Set[int]:
         return self._all_stop_token_ids
 
     def clone(self) -> "SamplingParams":
-        """Deep copy excluding LogitsProcessor objects.
+        """Deep copy, but maybe not the LogitsProcessor objects.
 
-        LogitsProcessor objects are excluded because they may contain an
-        arbitrary, nontrivial amount of data.
+        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
+        data that is expensive to copy. However, if not copied, the processor
+        needs to support parallel decoding for multiple sequences
         See https://github.com/vllm-project/vllm/issues/3087
         """
 
         logit_processor_refs = None if self.logits_processors is None else {
-            id(lp): lp
+            id(lp): lp.clone() if hasattr(lp, 'clone') else lp
             for lp in self.logits_processors
         }
         return copy.deepcopy(self, memo=logit_processor_refs)
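
The mechanism here is the memo argument of copy.deepcopy: pre-seeding it with id(original) -> replacement makes deepcopy reuse the replacement instead of recursing into that object. Processors that expose clone() therefore get a genuine per-request copy, while everything else keeps being shared by reference. The classes below are toys written for illustration, but the memo behaviour itself is standard-library semantics.

import copy


class SharedProcessor:
    """Stand-in for a processor that is cheap to share by reference."""


class StatefulProcessor:
    """Stand-in for a processor that provides its own clone()."""

    def __init__(self) -> None:
        self.state = 0

    def clone(self) -> "StatefulProcessor":
        return copy.deepcopy(self)


class ToyParams:
    """Mimics the SamplingParams.clone() pattern shown in the diff above."""

    def __init__(self, processors):
        self.processors = processors

    def clone(self) -> "ToyParams":
        refs = {id(p): p.clone() if hasattr(p, "clone") else p
                for p in self.processors}
        # deepcopy consults `memo` before copying: objects found there are
        # substituted as-is rather than deep-copied.
        return copy.deepcopy(self, memo=refs)


shared, stateful = SharedProcessor(), StatefulProcessor()
parent = ToyParams([shared, stateful])
child = parent.clone()
assert child.processors[0] is shared        # shared by reference
assert child.processors[1] is not stateful  # cloned per request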

vllm/sequence.py

Lines changed: 1 addition & 1 deletion

@@ -1372,7 +1372,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
     @staticmethod
     def add_request(request_id: str, engine, params, **kwargs):
         original_params = params
-        params = copy.deepcopy(original_params)
+        params = original_params.clone()
         params.n = 1
         group = ParallelSampleSequenceGroup(request_id)
         seqs = []
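
This call site is where the new clone paths come together: a request with n > 1 is handled by ParallelSampleSequenceGroup, and the parent's SamplingParams is now copied through SamplingParams.clone(), which dispatches to each processor's clone() when present, instead of a blanket copy.deepcopy. A quick way to see the effect, assuming vLLM is installed and using a hypothetical stateful processor like the one sketched earlier:

import copy

from vllm import SamplingParams


class _Stateful:
    """Hypothetical stateful logits processor with a clone() hook."""

    def __init__(self) -> None:
        self.state = 0

    def __call__(self, input_ids, scores):
        self.state += 1
        return scores

    def clone(self) -> "_Stateful":
        return copy.deepcopy(self)


parent = SamplingParams(n=4, logits_processors=[_Stateful()])
child = parent.clone()
child.n = 1  # mirrors `params.n = 1` in add_request above
assert child.logits_processors[0] is not parent.logits_processors[0]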
