Skip to content

Commit 59073e3

Browse files
sahilsuneja1njhill
authored andcommitted
test for stop_reason
1 parent 24dd6da commit 59073e3

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

tests/samplers/test_stop_reason.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Test the different finish_reason="stop" situations during generation:
2+
1. One of the provided stop strings
3+
2. One of the provided stop tokens
4+
3. The EOS token
5+
6+
Run `pytest tests/samplers/test_stop_reason.py`.
7+
"""
8+
9+
import pytest
10+
import transformers
11+
12+
from vllm import SamplingParams
13+
14+
MODEL = "facebook/opt-350m"
15+
STOP_STR = "."
16+
SEED = 42
17+
MAX_TOKENS = 1024
18+
19+
20+
@pytest.fixture
21+
def vllm_model(vllm_runner):
22+
vllm_model = vllm_runner(MODEL)
23+
yield vllm_model
24+
del vllm_model
25+
26+
27+
def test_stop_reason(vllm_model, example_prompts):
28+
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
29+
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
30+
llm = vllm_model.model
31+
32+
# test stop token
33+
outputs = llm.generate(example_prompts,
34+
sampling_params=SamplingParams(
35+
seed=SEED,
36+
max_tokens=MAX_TOKENS,
37+
stop_token_ids=[stop_token_id]))
38+
for output in outputs:
39+
output = output.outputs[0]
40+
assert output.finish_reason == "stop"
41+
assert output.stop_reason == stop_token_id
42+
43+
# test stop string
44+
outputs = llm.generate(example_prompts,
45+
sampling_params=SamplingParams(
46+
seed=SEED, max_tokens=MAX_TOKENS, stop="."))
47+
for output in outputs:
48+
output = output.outputs[0]
49+
assert output.finish_reason == "stop"
50+
assert output.stop_reason == STOP_STR
51+
52+
# test EOS token
53+
outputs = llm.generate(example_prompts,
54+
sampling_params=SamplingParams(
55+
seed=SEED, max_tokens=MAX_TOKENS))
56+
for output in outputs:
57+
output = output.outputs[0]
58+
assert output.finish_reason == "length" or (
59+
output.finish_reason == "stop" and output.stop_reason is None)

0 commit comments

Comments
 (0)