
Commit e5b2f16

[Frontend] Do prompt_logprobs clamping for chat as well as completions (#14225)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 9badee5 commit e5b2f16

File tree

3 files changed: +22 additions, -11 deletions

vllm/entrypoints/openai/serving_chat.py

Lines changed: 3 additions & 2 deletions
@@ -24,7 +24,8 @@
                                         RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                        ReasoningParserManager)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
@@ -844,7 +845,7 @@ async def chat_completion_full_generator(
             model=model_name,
             choices=choices,
             usage=usage,
-            prompt_logprobs=final_res.prompt_logprobs,
+            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
         )

         return response
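For context, this code path is exercised when a client asks for prompt logprobs on a chat request. A minimal sketch of such a request, assuming a vLLM server at http://localhost:8000 serving a model named "my-model" and accepting vLLM's extra "prompt_logprobs" request field; none of these specifics come from this commit:

import requests

# Hypothetical request for illustration only: the URL, model name and the
# vLLM-specific "prompt_logprobs" field are assumptions, not part of the diff.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "prompt_logprobs": 1,
    },
)
# Before this commit, a prompt token whose logprob was -inf could leave a
# non-JSON-compliant value in the chat response; it is now clamped to -9999.0.
print(resp.json().get("prompt_logprobs"))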

vllm/entrypoints/openai/serving_completion.py

Lines changed: 3 additions & 8 deletions
@@ -23,7 +23,8 @@
                                         RequestResponseMetadata,
                                         UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -394,13 +395,7 @@ def request_output_to_completion_response(
         for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
-            prompt_logprobs = final_res.prompt_logprobs
-            if prompt_logprobs:
-                for logprob_dict in prompt_logprobs:
-                    if logprob_dict:
-                        for logprob_values in logprob_dict.values():
-                            if logprob_values.logprob == float('-inf'):
-                                logprob_values.logprob = -9999.0
+            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
             prompt_text = final_res.prompt

             token_ids: GenericSequence[int]
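The inline loop removed above (and the shared helper that replaces it) clamps -inf to -9999.0 because strict JSON has no representation for infinities. A standard-library-only illustration of the failure mode, separate from vLLM's own response serialization:

import json

# Strict JSON rejects -inf outright...
try:
    json.dumps({"logprob": float("-inf")}, allow_nan=False)
except ValueError as exc:
    print(f"unclamped logprob rejected: {exc}")

# ...whereas the clamped sentinel serializes cleanly.
print(json.dumps({"logprob": -9999.0}, allow_nan=False))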

vllm/entrypoints/openai/serving_engine.py

Lines changed: 16 additions & 1 deletion
@@ -42,7 +42,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
+from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -535,3 +535,18 @@ def _get_model_name(self,
         if model_name is None:
             return self.models.base_model_paths[0].name
         return model_name
+
+
+def clamp_prompt_logprobs(
+    prompt_logprobs: Union[PromptLogprobs,
+                           None]) -> Union[PromptLogprobs, None]:
+    if prompt_logprobs is None:
+        return prompt_logprobs
+
+    for logprob_dict in prompt_logprobs:
+        if logprob_dict is None:
+            continue
+        for logprob_values in logprob_dict.values():
+            if logprob_values.logprob == float('-inf'):
+                logprob_values.logprob = -9999.0
+    return prompt_logprobs
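A quick sketch of the new helper in isolation. The construction below assumes Logprob is a small dataclass whose logprob field is mutable, and that PromptLogprobs is a list of optional per-token dicts keyed by token id (the leading None mirroring the first prompt token, which carries no logprob); treat it as illustrative rather than a test from this commit:

from vllm.entrypoints.openai.serving_engine import clamp_prompt_logprobs
from vllm.sequence import Logprob

# Illustrative prompt logprobs: no entry for the first token, and one candidate
# for the second token whose logprob underflowed to -inf.
prompt_logprobs = [
    None,
    {42: Logprob(logprob=float("-inf"))},
]

clamped = clamp_prompt_logprobs(prompt_logprobs)
print(clamped[1][42].logprob)     # -9999.0 after clamping
print(clamp_prompt_logprobs(None))  # None passes through unchanged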
