
Commit 48260e5

mhendrey, ShangmingCai, youkaichao, hmellor, and terrytangyuan authored and committed
[Frontend] generation_config.json for maximum tokens (vllm-project#12242)
Signed-off-by: Matthew Hendrey <[email protected]>
Signed-off-by: Shangming Cai <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Yuan Tang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: wangxiyuan <[email protected]>
Co-authored-by: shangmingc <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Yuan Tang <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent 2192644 commit 48260e5

File tree: 4 files changed, +145 -9 lines changed

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 110 additions & 0 deletions
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
+    # Setting server's max_tokens in the generation_config.json
+    # lower than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 10  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test Case 1: No max_tokens specified in request
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 15
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+
+    # Setting server's max_tokens in the generation_config.json
+    # higher than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 200  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test case 1: No max_tokens specified, defaults to context_window
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 100
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
 
 def test_serving_chat_could_load_correct_generation_config():
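Note on the expected values: the assertions follow the rule this PR introduces of taking the smallest limit that is actually set. A standalone sketch of that rule (assuming the mocked context window minus the rendered prompt leaves 93 tokens, which is what the `== 93` assertions imply but is not visible in this hunk; `clamp` is an illustrative helper, not vLLM code):

    def clamp(default_max_tokens, request_max_tokens, server_max_tokens):
        # Smallest of the limits that are set; None means "not specified".
        return min(v for v in (default_max_tokens, request_max_tokens,
                               server_max_tokens) if v is not None)

    assert clamp(93, None, 10) == 10   # server cap of 10, request leaves it unset
    assert clamp(93, 15, 10) == 10     # request asks for more than the server allows
    assert clamp(93, 5, 10) == 5       # request asks for less than the server cap
    assert clamp(93, None, 200) == 93  # server cap above what the context window leaves
    assert clamp(93, 100, 200) == 93
    assert clamp(93, 5, 200) == 5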

vllm/config.py

Lines changed: 6 additions & 0 deletions
@@ -910,12 +910,18 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
             "top_k",
             "top_p",
             "min_p",
+            "max_new_tokens",
         ]
         if any(p in config for p in available_params):
             diff_sampling_param = {
                 p: config.get(p)
                 for p in available_params if config.get(p) is not None
             }
+            # Huggingface definition of max_new_tokens is equivalent
+            # to vLLM's max_tokens
+            if "max_new_tokens" in diff_sampling_param:
+                diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
+                    "max_new_tokens")
         else:
             diff_sampling_param = {}
         return diff_sampling_param
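The remapping above keeps vLLM's own naming while honoring the HuggingFace generation_config.json convention. A minimal standalone sketch of the same renaming (the config dict and its values are made up for illustration, and only a subset of the real `available_params` list is shown):

    # HuggingFace-style generation config: "max_new_tokens" is its name for the
    # output-length limit that vLLM's SamplingParams calls "max_tokens".
    hf_generation_config = {"temperature": 0.7, "max_new_tokens": 128}

    available_params = ["temperature", "top_k", "top_p", "min_p", "max_new_tokens"]
    diff_sampling_param = {
        p: hf_generation_config.get(p)
        for p in available_params if hf_generation_config.get(p) is not None
    }
    if "max_new_tokens" in diff_sampling_param:
        diff_sampling_param["max_tokens"] = diff_sampling_param.pop("max_new_tokens")

    print(diff_sampling_param)  # {'temperature': 0.7, 'max_tokens': 128}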

vllm/engine/arg_utils.py

Lines changed: 3 additions & 1 deletion
@@ -939,7 +939,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "Defaults to None, will use the default generation config in vLLM. "
             "If set to 'auto', the generation config will be automatically "
             "loaded from model. If set to a folder path, the generation config "
-            "will be loaded from the specified folder path.")
+            "will be loaded from the specified folder path. If "
+            "`max_new_tokens` is specified, then it sets a server-wide limit "
+            "on the number of output tokens for all requests.")
 
         parser.add_argument("--enable-sleep-mode",
                             action="store_true",
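Following the updated help text, a deployment could keep a HuggingFace-style generation_config.json in a folder and point `--generation-config` at it. A hypothetical setup (the path and the 256-token cap are illustrative, not taken from the repository):

    import json
    from pathlib import Path

    # Folder later passed on the command line as: --generation-config ./gen_config
    config_dir = Path("./gen_config")
    config_dir.mkdir(exist_ok=True)

    # Per the help text above, max_new_tokens here becomes a server-wide limit
    # on output tokens for all requests (mapped to vLLM's max_tokens).
    (config_dir / "generation_config.json").write_text(
        json.dumps({"max_new_tokens": 256}))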

vllm/entrypoints/openai/protocol.py

Lines changed: 26 additions & 8 deletions
@@ -380,13 +380,17 @@ def to_beam_search_params(
     ) -> BeamSearchParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

@@ -406,11 +410,16 @@ def to_sampling_params(
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
+
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(

@@ -740,13 +749,17 @@ def to_beam_search_params(
             default_sampling_params: Optional[dict] = None
     ) -> BeamSearchParams:
         max_tokens = self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get("temperature", 1.0)

@@ -764,11 +777,16 @@ def to_sampling_params(
             logits_processor_pattern: Optional[str],
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
         max_tokens = self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
+
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
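The same pattern is applied in all four methods: instead of falling back to `default_max_tokens` only when the request omits `max_tokens`, the smallest limit that is actually set now wins. A small before/after sketch with illustrative numbers (none of these values come from the repository):

    default_max_tokens = 4096                      # context window minus prompt tokens
    requested = None                               # request's max_tokens (unset here)
    default_sampling_params = {"max_tokens": 256}  # server limit from generation_config.json

    # Old behavior: fall back to the context-window limit only when the request
    # leaves max_tokens unset; the server-side value was never consulted.
    old = requested if requested is not None else default_max_tokens  # -> 4096

    # New behavior: the smallest limit that is not None wins.
    new = min(v for v in (default_max_tokens, requested,
                          default_sampling_params.get("max_tokens", None))
              if v is not None)                                        # -> 256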

0 commit comments
