
Commit c79f8e3

njhill authored and joerunde committed
Changes pending from vllm-project/vllm#2976
Include matched stop string/token in responses [Cherry-picked from open upstream PR vllm-project/vllm#2976]

Currently a finish_reason of "stop" is returned if any of the following are encountered:

- One of the provided stop strings
- One of the provided stop tokens
- The EOS token

It can be useful to know specifically which of these caused the sequence generation to stop, especially since by default the stop strings/tokens are omitted from the output text (and output token_ids?).

This PR adds a "stop_reason" field to the CompletionOutput class which will contain the matched stop string or integer token id. It will be None otherwise, including in the EOS token case. In particular, this means that EOS can be inferred from (finish_reason == "stop" and stop_reason is None).

I've also added the field to the OpenAI server responses, but I'm not sure whether it should be included there since it isn't part of the official API.

Signed-off-by: Joe Runde <[email protected]>
1 parent 9bdd013 commit c79f8e3
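
To make the new contract concrete, here is a minimal consumer-side sketch (not part of this commit) of how a caller might interpret the finish_reason/stop_reason pair described above; the helper name is made up for illustration.

from typing import Optional, Union

# Minimal sketch: interpret the (finish_reason, stop_reason) pair per the description
# above. A None stop_reason together with finish_reason == "stop" implies the EOS token.
def describe_finish(finish_reason: Optional[str],
                    stop_reason: Union[int, str, None]) -> str:
    if finish_reason == "length":
        return "hit the max_tokens limit"
    if finish_reason == "stop":
        if stop_reason is None:
            return "generated the EOS token"
        if isinstance(stop_reason, int):
            return f"matched stop token id {stop_reason}"
        return f"matched stop string {stop_reason!r}"
    return "still generating"

print(describe_finish("stop", None))    # generated the EOS token
print(describe_finish("stop", "\n\n"))  # matched stop string '\n\n'
print(describe_finish("stop", 50256))   # matched stop token id 50256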

6 files changed: +25 −5 lines

vllm/engine/llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -1013,12 +1013,15 @@ def _check_stop(self, seq: Sequence,
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return

         # Check if the sequence has generated the EOS token.
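
For readers skimming the hunk, a standalone sketch of the decision order it implements follows; this is illustrative code with plain arguments, not vLLM internals, and the EOS behaviour (stop_reason left as None) is taken from the commit description rather than from this hunk.

from typing import List, Tuple, Union

# Illustrative re-statement of the stop-check order: stop strings first, then stop
# token ids, then EOS. Only the first two record a stop_reason; EOS leaves it None.
def classify_stop(output_text: str,
                  last_token_id: int,
                  stop: List[str],
                  stop_token_ids: List[int],
                  eos_token_id: int) -> Tuple[bool, Union[int, str, None]]:
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return True, stop_str       # finished on a stop string
    if last_token_id in stop_token_ids:
        return True, last_token_id      # finished on a stop token id
    if last_token_id == eos_token_id:
        return True, None               # finished on EOS: stop_reason stays None
    return False, None                  # no stop condition hit

print(classify_stop("Hello world\n", 13, ["\n"], [], 2))  # (True, '\n')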

vllm/entrypoints/openai/protocol.py

Lines changed: 4 additions & 0 deletions
@@ -254,6 +254,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class CompletionResponse(BaseModel):
@@ -270,6 +271,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class CompletionStreamResponse(BaseModel):
@@ -291,6 +293,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionResponse(BaseModel):
@@ -312,6 +315,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionStreamResponse(BaseModel):
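
As a rough illustration of what these model changes mean for serialized responses, the snippet below constructs one of the updated choices directly; the index field and the concrete values are assumptions made for the example, not taken from this diff.

# Illustrative only: the updated choice models now serialize a "stop_reason" field
# next to "finish_reason". Values below are made up; `index` is assumed required.
from vllm.entrypoints.openai.protocol import CompletionResponseChoice

choice = CompletionResponseChoice(
    index=0,
    text="Paris.",
    finish_reason="stop",
    stop_reason="\n\n",  # matched stop string; an int for a stop token id; None for EOS
)
print(choice.model_dump_json())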

vllm/entrypoints/openai/serving_chat.py

Lines changed: 3 additions & 1 deletion
@@ -213,7 +213,8 @@ async def chat_completion_stream_generator(
                     index=i,
                     delta=DeltaMessage(content=delta_text),
                     logprobs=logprobs,
-                    finish_reason=output.finish_reason)
+                    finish_reason=output.finish_reason,
+                    stop_reason=output.stop_reason)
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
                     object=chunk_object_type,
@@ -271,6 +272,7 @@ async def chat_completion_full_generator(
             message=ChatMessage(role=role, content=output.text),
             logprobs=logprobs,
             finish_reason=output.finish_reason,
+            stop_reason=output.stop_reason,
         )
         choices.append(choice_data)

vllm/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 0 deletions
@@ -261,6 +261,7 @@ async def completion_stream_generator(
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 finish_reason = output.finish_reason
+                stop_reason = output.stop_reason
                 response_json = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
@@ -271,6 +272,7 @@ async def completion_stream_generator(
                             text=delta_text,
                             logprobs=logprobs,
                             finish_reason=finish_reason,
+                            stop_reason=stop_reason,
                         )
                     ]).model_dump_json()
                 yield f"data: {response_json}\n\n"
@@ -295,6 +297,7 @@ async def completion_stream_generator(
                             text="",
                             logprobs=logprobs,
                             finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason,
                         )
                     ],
                     usage=final_usage,
@@ -354,6 +357,7 @@ def request_output_to_completion_response(
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
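
For the server side, a hedged end-to-end sketch: assuming a vLLM OpenAI-compatible server is already running locally on the default port, a client could read the new field from each choice like this (model name and prompt are illustrative).

import requests

# Assumes an OpenAI-compatible vLLM server at this URL; adjust model/prompt as needed.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "stop": ["\n"],
    },
)
choice = resp.json()["choices"][0]
print(choice["finish_reason"], choice["stop_reason"])
# e.g. "stop" "\n" when the stop string matched, or "stop" None when EOS was emitted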

vllm/outputs.py

Lines changed: 8 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 import time

 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion to stop,
+            None if the completion finished for some other reason including
+            encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """

@@ -29,6 +32,7 @@ def __init__(
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ def __init__(
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request

     def finished(self) -> bool:
@@ -48,7 +53,8 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")


 class RequestOutput:
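
Putting it together for the offline API, a minimal usage sketch of the new CompletionOutput field follows; the model name is illustrative and any small model would do.

from vllm import LLM, SamplingParams

# Sketch of reading stop_reason from the offline API (model choice is illustrative).
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=64, stop=["\n"])

for request_output in llm.generate(["The capital of France is"], params):
    completion = request_output.outputs[0]
    print(repr(completion.text))
    print(completion.finish_reason)  # "stop" or "length"
    print(completion.stop_reason)    # matched stop string / token id, or None (EOS)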

vllm/sequence.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def __init__(
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None

         # Used for incremental detokenization
         self.prefix_offset = 0
