
Commit 9c7468f

gcalmettes authored and zRzRzRzRzRzRzR committed
[Bugfix][Frontend] respect provided default guided decoding backend (#15476)
Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
1 parent 2a91bc1 commit 9c7468f

File tree

2 files changed: +78 -5 lines changed


tests/test_sampling_params.py

Lines changed: 78 additions & 3 deletions
@@ -1,14 +1,89 @@
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the SamplingParams class.
 """
+
+import pytest
+
 from vllm import SamplingParams
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+MODEL_NAME = "Qwen/Qwen1.5-7B"
 
 
 def test_max_tokens_none():
     """max_tokens=None should be allowed"""
     SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
 
 
-if __name__ == "__main__":
-    import pytest
-    pytest.main([__file__])
+@pytest.fixture(scope="module")
+def model_config():
+    return ModelConfig(
+        MODEL_NAME,
+        task="auto",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+
+@pytest.fixture(scope="module")
+def default_max_tokens():
+    return 4096
+
+
+def test_sampling_params_from_request_with_no_guided_decoding_backend(
+        model_config, default_max_tokens):
+    # guided_decoding_backend is not present at request level
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # we do not expect any backend to be present and the default
+    # guided_decoding_backend at engine level will be used.
+    assert sampling_params.guided_decoding.backend is None
+
+
+@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
+                         [("xgrammar", "xgrammar"),
+                          ("lm-format-enforcer", "lm-format-enforcer"),
+                          ("outlines", "outlines")])
+def test_sampling_params_from_request_with_guided_decoding_backend(
+        request_level_guided_decoding_backend: str, expected: str,
+        model_config, default_max_tokens):
+
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+        'guided_decoding_backend':
+        request_level_guided_decoding_backend,
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # backend correctly identified in resulting sampling_params
+    assert sampling_params.guided_decoding.backend == expected

vllm/entrypoints/openai/protocol.py

Lines changed: 0 additions & 2 deletions
@@ -476,8 +476,6 @@ def to_sampling_params(
             json_schema = self.response_format.json_schema
             assert json_schema is not None
             self.guided_json = json_schema.json_schema
-            if self.guided_decoding_backend is None:
-                self.guided_decoding_backend = "xgrammar"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
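
Usage note (not part of the commit): a minimal client-side sketch of what this fix implies, assuming a vLLM OpenAI-compatible server launched with an engine-level default such as --guided-decoding-backend=outlines. With the hardcoded fallback removed, a JSON-mode request that omits guided_decoding_backend leaves the backend unset in SamplingParams, so the engine-level default applies; a per-request value passed through extra_body still overrides it. The server URL, model name, and backend choices below are illustrative assumptions.

# Illustrative sketch only; assumes a local vLLM OpenAI-compatible server at
# http://localhost:8000/v1 started with --guided-decoding-backend=outlines.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No guided_decoding_backend in the request: with this fix the backend stays
# None in the resulting SamplingParams, so the engine-level default is used.
client.chat.completions.create(
    model="Qwen/Qwen1.5-7B",
    messages=[{"role": "user", "content": "Hello"}],
    response_format={"type": "json_object"},
)

# An explicit per-request backend (vLLM extension passed via extra_body) still
# takes precedence over the engine-level default.
client.chat.completions.create(
    model="Qwen/Qwen1.5-7B",
    messages=[{"role": "user", "content": "Hello"}],
    response_format={"type": "json_object"},
    extra_body={"guided_decoding_backend": "lm-format-enforcer"},
)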
