
Commit be31e4d

hmellor authored and lulmer committed
Fix performance when --generation-config is not None (vllm-project#14223)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 3083ed7 commit be31e4d

4 files changed: +23 −25 lines
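
Across all four files the change is the same: the sampling-parameter overrides derived from the model's generation config (`model_config.get_diff_sampling_param()`) are now resolved once and stored on the serving object as `self.default_sampling_params`, instead of being recomputed on every request. Below is a minimal sketch of that pattern; `ExpensiveConfig` and `get_overrides()` are illustrative stand-ins, not vLLM APIs.

# Sketch of the caching pattern this commit applies; ExpensiveConfig and
# get_overrides() are hypothetical stand-ins, not vLLM classes.
from typing import Any


class ExpensiveConfig:
    """Stand-in for ModelConfig, where deriving the overrides is assumed costly."""

    def get_overrides(self) -> dict[str, Any]:
        # Imagine this re-parses a generation config every time it is called.
        return {"temperature": 0.7, "top_p": 0.9}


class Server:
    def __init__(self, config: ExpensiveConfig) -> None:
        # After the fix: resolve the overrides once, at construction time.
        self.default_sampling_params = config.get_overrides()

    def handle_request(self) -> dict[str, Any]:
        # Hot path: reuse the cached dict instead of calling get_overrides().
        return dict(self.default_sampling_params)

The repeated lookup only costs anything when a generation config is actually supplied, which is presumably why the title scopes the fix to the case where --generation-config is not None.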

vllm/entrypoints/llm.py

Lines changed: 6 additions & 4 deletions
@@ -244,6 +244,7 @@ def __init__(
             engine_args, usage_context=UsageContext.LLM_CLASS)
 
         self.request_counter = Counter()
+        self.default_sampling_params: Union[dict[str, Any], None] = None
 
     @staticmethod
     def get_engine_class() -> type[LLMEngine]:
@@ -268,10 +269,11 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
 
     def get_default_sampling_params(self) -> SamplingParams:
-        diff_sampling_param = (
-            self.llm_engine.model_config.get_diff_sampling_param())
-        if diff_sampling_param:
-            return SamplingParams.from_optional(**diff_sampling_param)
+        if self.default_sampling_params is None:
+            self.default_sampling_params = (
+                self.llm_engine.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
+            return SamplingParams.from_optional(**self.default_sampling_params)
         return SamplingParams()
 
     @overload
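
Unlike the OpenAI serving classes below, `LLM` fills the cache lazily: the attribute starts as `None` in `__init__` and is populated on the first call to `get_default_sampling_params()`. A small standalone sketch of that shape, where `_compute()` is a hypothetical stand-in for `self.llm_engine.model_config.get_diff_sampling_param()`:

# Lazy-memoization sketch; _compute() stands in for the real config lookup.
from typing import Any, Optional


class LazyDefaults:
    def __init__(self) -> None:
        # None means "not computed yet"; an empty dict means "computed,
        # but the generation config adds no overrides".
        self._defaults: Optional[dict[str, Any]] = None

    def _compute(self) -> dict[str, Any]:
        return {"max_tokens": 128}

    def get(self) -> dict[str, Any]:
        if self._defaults is None:        # only the first call pays the cost
            self._defaults = self._compute()
        return self._defaults             # later calls return the cached dict

Using None as the sentinel, rather than truthiness alone, is what lets an empty override dict still count as "already computed" after the first lookup.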

vllm/entrypoints/openai/serving_chat.py

Lines changed: 6 additions & 8 deletions
@@ -105,10 +105,11 @@ def __init__(
                                 "been registered") from e
 
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info("Overwriting default chat sampling param with: %s",
-                        diff_sampling_param)
+                        self.default_sampling_params)
 
     async def create_chat_completion(
         self,
@@ -210,17 +211,14 @@ async def create_chat_completion(
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)
 
                 self._log_inputs(request_id,
                                  request_prompts[i],
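
In the chat server the lookup now happens once in `__init__`, and `create_chat_completion` only reads `self.default_sampling_params` on each request. A rough, self-contained illustration of why hoisting the lookup out of the per-request path matters; `slow_lookup()` is a hypothetical stand-in for a generation-config lookup that costs about a millisecond, not a measurement of vLLM itself:

# Illustrative micro-benchmark, not vLLM code: compares re-running a slow
# config lookup per request against doing it once up front.
import time


def slow_lookup() -> dict:
    time.sleep(0.001)  # pretend parsing the generation config costs ~1 ms
    return {"temperature": 0.7}


def per_request(n: int) -> float:
    start = time.perf_counter()
    for _ in range(n):
        defaults = slow_lookup()   # old shape: lookup inside the request path
        _ = dict(defaults)
    return time.perf_counter() - start


def cached(n: int) -> float:
    defaults = slow_lookup()       # new shape: lookup once, before serving
    start = time.perf_counter()
    for _ in range(n):
        _ = dict(defaults)
    return time.perf_counter() - start


if __name__ == "__main__":
    n = 1000
    print(f"per-request lookup: {per_request(n):.3f}s, cached: {cached(n):.3f}s")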

vllm/entrypoints/openai/serving_completion.py

Lines changed: 6 additions & 8 deletions
@@ -51,11 +51,12 @@ def __init__(
                          models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)
 
     async def create_completion(
         self,
@@ -119,17 +120,14 @@ async def create_completion(
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)
 
                 request_id_item = f"{request_id}-{i}"
 
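
The cached dict is handed to `request.to_beam_search_params(...)` and `request.to_sampling_params(...)` exactly as the freshly computed one was before, so request-level behaviour should be unchanged. Conceptually, the defaults act as a base layer that explicit per-request fields override; the sketch below illustrates that idea with a plain dict merge and is not the actual `to_sampling_params` implementation:

# Conceptual sketch only: server-side defaults as a base layer, with fields
# set on the request taking precedence. Not vLLM's to_sampling_params logic.
from typing import Any, Optional


def build_sampling_kwargs(
    default_sampling_params: Optional[dict[str, Any]],
    request_overrides: dict[str, Any],
) -> dict[str, Any]:
    merged = dict(default_sampling_params or {})  # tolerate None or {}
    merged.update(request_overrides)              # request wins on conflicts
    return merged


# Example: the server default temperature applies unless the request sets one.
print(build_sampling_kwargs({"temperature": 0.7}, {"max_tokens": 64}))
# -> {'temperature': 0.7, 'max_tokens': 64}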

vllm/entrypoints/openai/serving_transcription.py

Lines changed: 5 additions & 5 deletions
@@ -161,11 +161,12 @@ def __init__(
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)
 
     async def _preprocess_transcription(
         self,
@@ -273,9 +274,8 @@ async def create_transcription(
         try:
             # TODO(rob): subtract len of tokenized prompt.
             default_max_tokens = self.model_config.max_model_len
-            default_params = self.model_config.get_diff_sampling_param()
             sampling_params = request.to_sampling_params(
-                default_max_tokens, default_params)
+                default_max_tokens, self.default_sampling_params)
 
             self._log_inputs(
                 request_id,
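
serving_chat.py, serving_completion.py and serving_transcription.py each repeat the same __init__ snippet: resolve the overrides once, then log them if non-empty. A hypothetical consolidation, not part of this commit, would move that into a small shared helper; `model_config.get_diff_sampling_param()` is the only vLLM call assumed here, as it appears in the diffs above.

# Hypothetical shared helper (not in this commit): resolve generation-config
# overrides once and log them, as each serving class now does individually.
import logging
from typing import Any

logger = logging.getLogger(__name__)


class DefaultSamplingParamsMixin:
    def init_default_sampling_params(self, model_config: Any, kind: str) -> None:
        # Assumes model_config exposes get_diff_sampling_param(), as in the diff.
        self.default_sampling_params = model_config.get_diff_sampling_param()
        if self.default_sampling_params:
            logger.info("Overwriting default %s sampling param with: %s",
                        kind, self.default_sampling_params)

Each server would then call self.init_default_sampling_params(self.model_config, "chat") (or "completion") from its own __init__.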
