Commit bc3ea46

njhill and sahilsuneja1 authored and jimpang committed
[Misc] Include matched stop string/token in responses (vllm-project#2976)
Co-authored-by: Sahil Suneja <[email protected]>
1 parent 19d7628 commit bc3ea46

File tree

7 files changed: +97, -7 lines changed

tests/samplers/test_stop_reason.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+"""Test the different finish_reason="stop" situations during generation:
+    1. One of the provided stop strings
+    2. One of the provided stop tokens
+    3. The EOS token
+
+Run `pytest tests/samplers/test_stop_reason.py`.
+"""
+
+import pytest
+import transformers
+
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-350m"
+STOP_STR = "."
+SEED = 42
+MAX_TOKENS = 1024
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL)
+    yield vllm_model
+    del vllm_model
+
+
+def test_stop_reason(vllm_model, example_prompts):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
+    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
+    llm = vllm_model.model
+
+    # test stop token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED,
+                               max_tokens=MAX_TOKENS,
+                               stop_token_ids=[stop_token_id]))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == stop_token_id
+
+    # test stop string
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS, stop="."))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == STOP_STR
+
+    # test EOS token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "length" or (
+            output.finish_reason == "stop" and output.stop_reason is None)
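
Beyond the test above, the new field is also visible from the offline LLM API. A minimal sketch, reusing the model from the test; the prompt and stop string are illustrative and not part of this commit:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-350m")
params = SamplingParams(max_tokens=64, stop=["."])

for request_output in llm.generate(["The capital of France is"], params):
    completion = request_output.outputs[0]
    # finish_reason == "stop" when a stop string or stop token matched;
    # stop_reason then carries the matched string (or token id), else None.
    print(completion.finish_reason, repr(completion.stop_reason))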

vllm/engine/llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -740,12 +740,15 @@ def _check_stop(self, seq: Sequence,
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return

         # Check if the sequence has generated the EOS token.
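
To summarize the branching above: a matched stop string records the string itself, a matched stop token records the token id, and finishing on EOS or max length leaves stop_reason unset. A condensed, standalone sketch of that rule (illustrative names, not vLLM internals):

from typing import List, Union

def classify_stop_reason(output_text: str, last_token_id: int,
                         stop: List[str],
                         stop_token_ids: List[int]) -> Union[str, int, None]:
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return stop_str       # stop_reason = the matched stop string
    if last_token_id in stop_token_ids:
        return last_token_id      # stop_reason = the matched stop token id
    return None                   # EOS / max length: stop_reason stays None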

vllm/entrypoints/openai/protocol.py

Lines changed: 16 additions & 0 deletions
@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionResponse(BaseModel):
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionStreamResponse(BaseModel):
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionResponse(BaseModel):
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionStreamResponse(BaseModel):
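
Since the OpenAI-compatible schema now carries stop_reason on every choice, clients can read it straight from the response JSON. A hedged sketch against a locally running vLLM server; the URL, model, prompt, and stop string are assumptions for illustration:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-350m",
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "stop": ["."],
    },
)
choice = resp.json()["choices"][0]
# finish_reason follows the OpenAI schema; stop_reason is the vLLM extension
# added in this commit (matched stop string or token id, else null).
print(choice["finish_reason"], choice["stop_reason"])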

vllm/entrypoints/openai/serving_chat.py

Lines changed: 3 additions & 1 deletion
@@ -220,7 +220,8 @@ async def chat_completion_stream_generator(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason)
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
@@ -278,6 +279,7 @@ async def chat_completion_full_generator(
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)

vllm/entrypoints/openai/serving_completion.py

Lines changed: 3 additions & 0 deletions
@@ -266,6 +266,7 @@ async def completion_stream_generator(
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 finish_reason = output.finish_reason
+                stop_reason = output.stop_reason
                 if output.finish_reason is not None:  # return final usage
                     prompt_tokens = len(res.prompt_token_ids)
                     completion_tokens = len(output.token_ids)
@@ -286,6 +287,7 @@ async def completion_stream_generator(
                             text=delta_text,
                             logprobs=logprobs,
                             finish_reason=finish_reason,
+                            stop_reason=stop_reason,
                         )
                     ],
                     usage=final_usage,
@@ -342,6 +344,7 def request_output_to_completion_response(
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)

vllm/outputs.py

Lines changed: 10 additions & 4 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union

 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """

@@ -29,6 +32,7 @@ def __init__(
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ def __init__(
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request

     def finished(self) -> bool:
@@ -48,7 +53,8 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")


 class RequestOutput:
@@ -111,8 +117,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
                 seq.get_output_token_ids(),
                 seq.get_cumulative_logprob(),
                 seq.output_logprobs if include_logprobs else None,
-                SequenceStatus.get_finished_reason(seq.status))
-            for seq in top_n_seqs
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason) for seq in top_n_seqs
         ]

         # Every sequence in the sequence group should have the same prompt.

vllm/sequence.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ def __init__(
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None

         # Used for incremental detokenization
         self.prefix_offset = 0
