|
| 1 | +"""Test the different finish_reason="stop" situations during generation: |
| 2 | + 1. One of the provided stop strings |
| 3 | + 2. One of the provided stop tokens |
| 4 | + 3. The EOS token |
| 5 | +
|
| 6 | +Run `pytest tests/samplers/test_stop_reason.py`. |
| 7 | +""" |
| 8 | + |
| 9 | +import pytest |
| 10 | +import transformers |
| 11 | + |
| 12 | +from vllm import SamplingParams |
| 13 | + |
# Checkpoint name handed both to the vllm_runner fixture and to the
# Hugging Face tokenizer, so the two always agree on the vocabulary.
MODEL = "facebook/opt-350m"

# Sampling settings shared by every generate() call in this module.
SEED = 42
MAX_TOKENS = 1024

# Stop marker: looked up as a token id for the stop-token case and compared
# against stop_reason in the stop-string case.
STOP_STR = "."
| 18 | + |
| 19 | + |
@pytest.fixture
def vllm_model(vllm_runner):
    """Yield the object returned by ``vllm_runner(MODEL)`` to the test.

    After the test finishes, the fixture's local reference to that object
    is deleted.
    """
    runner = vllm_runner(MODEL)
    yield runner
    # Teardown: drop this fixture's own reference to the runner.
    del runner
| 25 | + |
| 26 | + |
def test_stop_reason(vllm_model, example_prompts):
    """Check finish_reason / stop_reason for three generation setups.

    The same prompts are generated three times with identical seed and
    max_tokens, varying only the stop criteria:

    1. ``stop_token_ids=[<id of STOP_STR>]``: every completion must finish
       with ``finish_reason == "stop"`` and ``stop_reason`` equal to that id.
    2. ``stop=STOP_STR``: every completion must finish with
       ``finish_reason == "stop"`` and ``stop_reason == STOP_STR``.
    3. No stop criteria: a completion may end with ``"length"``, or with
       ``"stop"`` and ``stop_reason is None`` (the EOS-token case listed in
       the module docstring).

    Args:
        vllm_model: Fixture value; its ``model`` attribute is the object
            whose ``generate`` method is exercised.
        example_prompts: Prompts passed unchanged to every ``generate`` call
            (presumably a shared pytest fixture; it is not defined here).
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
    llm = vllm_model.model

    # Case 1: stop on an explicit token id.
    outputs = llm.generate(
        example_prompts,
        sampling_params=SamplingParams(
            seed=SEED,
            max_tokens=MAX_TOKENS,
            stop_token_ids=[stop_token_id],
        ),
    )
    for request_output in outputs:
        # Only the first completion of each request is inspected.
        completion = request_output.outputs[0]
        assert completion.finish_reason == "stop"
        assert completion.stop_reason == stop_token_id

    # Case 2: stop on a string. Pass STOP_STR itself (not a duplicated
    # literal) so the request and the expectation below cannot drift apart.
    outputs = llm.generate(
        example_prompts,
        sampling_params=SamplingParams(
            seed=SEED,
            max_tokens=MAX_TOKENS,
            stop=STOP_STR,
        ),
    )
    for request_output in outputs:
        completion = request_output.outputs[0]
        assert completion.finish_reason == "stop"
        assert completion.stop_reason == STOP_STR

    # Case 3: no stop criteria supplied (EOS token case).
    outputs = llm.generate(
        example_prompts,
        sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
    )
    for request_output in outputs:
        completion = request_output.outputs[0]
        # Either the token budget ran out, or generation stopped without a
        # user-provided stop string/token being recorded as the reason.
        assert completion.finish_reason == "length" or (
            completion.finish_reason == "stop"
            and completion.stop_reason is None)
0 commit comments