22
22
23
23
def maybe_backend_fallback (
24
24
guided_params : GuidedDecodingParams ) -> GuidedDecodingParams :
25
+
26
def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                      fallback: str) -> None:
    """Switch ``guided_params`` to another backend, or refuse to.

    When ``guided_params.no_fallback()`` is falsy, a warning containing
    ``message`` and the name of the replacement backend is logged and
    ``guided_params.backend`` is overwritten in place with ``fallback``.
    Otherwise nothing is modified and the request is rejected.

    Args:
        guided_params: Parameters object whose ``backend`` attribute is
            reassigned on fallback.
        message: Explanation of why the current backend cannot be used;
            it serves as both the log prefix and the exception text.
        fallback: Value to store in ``guided_params.backend``.

    Raises:
        ValueError: If ``guided_params.no_fallback()`` is truthy.
    """
    if not guided_params.no_fallback():
        # Fallback is permitted: record why, then swap the backend in place.
        logger.warning("%s Falling back to use %s instead.", message, fallback)
        guided_params.backend = fallback
        return

    # Fallback was explicitly disabled, so surface the reason to the caller.
    raise ValueError(message)
35
+
25
36
# lm-format-enforcer doesn't support grammar, fall back to xgrammar
26
- if guided_params .backend == "lm-format-enforcer" :
37
+ if guided_params .backend_name == "lm-format-enforcer" :
27
38
if guided_params .grammar is not None :
28
- logger . warning (
29
- "lm-format-enforcer does not support grammar guided decoding. "
30
- "Falling back to use xgrammar instead." )
31
- guided_params . backend = "xgrammar"
39
+ fallback_or_error (
40
+ guided_params ,
41
+ "lm-format-enforcer does not support grammar guided decoding." ,
42
+ "xgrammar" )
32
43
33
44
# lm-format-enforcer doesn't support some JSON schema features
34
45
elif (guided_params .json is not None
35
46
and has_lmf_unsupported_json_features (guided_params .json )):
36
- logger .warning (
47
+ fallback_or_error (
48
+ guided_params ,
37
49
"lm-format-enforcer does not support advanced JSON schema "
38
- "features like patterns or numeric ranges. "
39
- "Falling back to use outlines instead." )
40
- guided_params .backend = "outlines"
50
+ "features like patterns or numeric ranges." , "outlines" )
41
51
42
- if guided_params .backend == "xgrammar" :
52
+ if guided_params .backend_name == "xgrammar" :
43
53
from vllm .model_executor .guided_decoding .xgrammar_decoding import (
44
54
xgr_installed )
45
55
# xgrammar only has x86 wheels for linux, fallback to outlines
46
56
from vllm .platforms import current_platform
47
57
if current_platform .get_cpu_architecture () is not CpuArchEnum .X86 :
48
- logger . warning ( "xgrammar is only supported on x86 CPUs. "
49
- "Falling back to use outlines instead." )
50
- guided_params . backend = "outlines"
58
+ fallback_or_error ( guided_params ,
59
+ "xgrammar is only supported on x86 CPUs." ,
60
+ "outlines" )
51
61
52
62
# xgrammar doesn't support regex, fallback to outlines
53
63
if guided_params .regex is not None :
54
- logger . warning ( "xgrammar does not support regex guided decoding. "
55
- "Falling back to use outlines instead." )
56
- guided_params . backend = " outlines"
64
+ fallback_or_error (
65
+ guided_params ,
66
+ "xgrammar does not support regex guided decoding." , " outlines")
57
67
58
68
# xgrammar doesn't support some JSON schema features
59
69
elif (guided_params .json is not None
60
70
and has_xgrammar_unsupported_json_features (guided_params .json )):
61
- logger .warning (
71
+ fallback_or_error (
72
+ guided_params ,
62
73
"xgrammar does not support advanced JSON schema features like "
63
- "patterns or numeric ranges. "
64
- "Falling back to use outlines instead." )
65
- guided_params .backend = "outlines"
74
+ "enums, patterns or numeric ranges." , "outlines" )
66
75
67
76
# xgrammar only supports GBNF grammars, so we must convert Lark.
68
77
# We must check if the grammar is likely Lark and if that
@@ -72,25 +81,23 @@ def maybe_backend_fallback(
72
81
try :
73
82
convert_lark_to_gbnf (guided_params .grammar )
74
83
except Exception :
75
- logger .warning (
84
+ fallback_or_error (
85
+ guided_params ,
76
86
"xgrammar does not support Lark grammars and the "
77
- "grammar failed to convert to GBNF. "
78
- "Falling back to use outlines instead." )
79
- guided_params .backend = "outlines"
87
+ "grammar failed to convert to GBNF." , "outlines" )
80
88
81
89
# If the xgrammar module cannot be imported successfully,
82
90
# we should still allow users to use guided decoding with a fallback.
83
91
elif not xgr_installed :
84
- logger . warning ( "xgrammar module cannot be imported successfully. "
85
- "Falling back to use outlines instead." )
86
- guided_params . backend = " outlines"
92
+ fallback_or_error (
93
+ guided_params ,
94
+ "xgrammar module cannot be imported successfully." , " outlines")
87
95
88
- if (guided_params .backend == "outlines"
96
+ if (guided_params .backend_name == "outlines"
89
97
and guided_params .json_object is not None ):
90
98
# outlines doesn't support json_object, fallback to xgrammar
91
- logger .warning ("outlines does not support json_object. "
92
- "Falling back to use xgrammar instead." )
93
- guided_params .backend = "xgrammar"
99
+ fallback_or_error (guided_params ,
100
+ "outlines does not support json_object." , "xgrammar" )
94
101
95
102
return guided_params
96
103
@@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor(
100
107
model_config : ModelConfig ) -> LogitsProcessor | None :
101
108
guided_params = maybe_backend_fallback (guided_params )
102
109
# CFG grammar not supported by LMFE, so we use outlines instead
103
- if guided_params .backend == 'outlines' :
110
+ if guided_params .backend_name == 'outlines' :
104
111
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
105
112
from vllm .model_executor .guided_decoding .outlines_decoding import ( # noqa
106
113
get_outlines_guided_decoding_logits_processor )
107
114
return await get_outlines_guided_decoding_logits_processor (
108
115
guided_params , tokenizer )
109
- if guided_params .backend == 'lm-format-enforcer' :
116
+ if guided_params .backend_name == 'lm-format-enforcer' :
110
117
from vllm .model_executor .guided_decoding .lm_format_enforcer_decoding import ( # noqa
111
118
get_local_lm_format_enforcer_guided_decoding_logits_processor )
112
119
return get_local_lm_format_enforcer_guided_decoding_logits_processor (
113
120
guided_params , tokenizer )
114
- if guided_params .backend == 'xgrammar' :
121
+ if guided_params .backend_name == 'xgrammar' :
115
122
from vllm .model_executor .guided_decoding .xgrammar_decoding import ( # noqa
116
123
get_local_xgrammar_guided_decoding_logits_processor )
117
124
return get_local_xgrammar_guided_decoding_logits_processor (
@@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor(
127
134
model_config : ModelConfig ) -> LogitsProcessor | None :
128
135
guided_params = maybe_backend_fallback (guided_params )
129
136
# CFG grammar not supported by LMFE, so we use outlines instead
130
- if guided_params .backend == 'outlines' :
137
+ if guided_params .backend_name == 'outlines' :
131
138
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
132
139
from vllm .model_executor .guided_decoding .outlines_decoding import ( # noqa
133
140
get_local_outlines_guided_decoding_logits_processor )
134
141
return get_local_outlines_guided_decoding_logits_processor (
135
142
guided_params , tokenizer )
136
- if guided_params .backend == 'lm-format-enforcer' :
143
+ if guided_params .backend_name == 'lm-format-enforcer' :
137
144
from vllm .model_executor .guided_decoding .lm_format_enforcer_decoding import ( # noqa
138
145
get_local_lm_format_enforcer_guided_decoding_logits_processor )
139
146
return get_local_lm_format_enforcer_guided_decoding_logits_processor (
140
147
guided_params , tokenizer )
141
- if guided_params .backend == 'xgrammar' :
148
+ if guided_params .backend_name == 'xgrammar' :
142
149
from vllm .model_executor .guided_decoding .xgrammar_decoding import ( # noqa
143
150
get_local_xgrammar_guided_decoding_logits_processor )
144
151
return get_local_xgrammar_guided_decoding_logits_processor (
0 commit comments