Skip to content

Commit ee18eb7

Browse files
committed
test logprobs
Signed-off-by: NickLucche <[email protected]>
1 parent e9f4438 commit ee18eb7

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

tests/v1/tpu/test_sampler.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,29 +62,50 @@ def test_sampler_different(model_name: str):
6262
# tokens match.
6363
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
6464

65+
6566
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
67+
# TODO TPU will appear busy if we fan-out test params here
68+
@pytest.mark.parametrize("n_prompts", [1])
6669
@pytest.mark.skipif(not current_platform.is_tpu(),
6770
reason="This test needs a TPU")
68-
def test_logprobs(model_name: str):
71+
def test_logprobs(model_name: str, n_prompts: int):
6972
"""
73+
Request top logprobs with different sampling settings and check
74+
that results contain the requested number, ordered by ascending rank
(i.e. descending log-probability).
7075
"""
76+
77+
def check_num_logprobs(logprobs, expected_num: int):
78+
for step in logprobs:
79+
prev_logp = 1.0
80+
# order by rank
81+
sorted_step = dict(
82+
sorted(step.items(), key=lambda item: item[1].rank))
83+
84+
# Can contain the sampled token
85+
assert len(step) == expected_num or len(step) == expected_num + 1
86+
# Check results are ordered by prob value
87+
for rankno, (tid, logp) in enumerate(sorted_step.items()):
88+
assert logp.logprob <= prev_logp
89+
prev_logp = logp.logprob
90+
assert logp.rank == rankno + 1
91+
7192
llm = LLM(model_name,
7293
enforce_eager=False,
7394
max_num_seqs=1,
74-
max_model_len=512,
75-
max_num_batched_tokens=512)
95+
max_model_len=128,
96+
max_num_batched_tokens=128)
7697
prompts = [
7798
"Write a short story about a robot that dreams for the first time."
78-
]
79-
# Greedy sampling
80-
sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
81-
output = llm.generate(prompts, sampling_params)
82-
print(output)
99+
] * n_prompts
100+
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
101+
logprobs=4)
102+
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
103+
logprobs=4)
104+
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
105+
logprobs=4, top_k=12, top_p=0.5)
83106

84-
sampling_params = SamplingParams(temperature=0.4, min_p=0.2, max_tokens=64, logprobs=4)
85-
output = llm.generate(prompts, sampling_params)
86-
print(output)
87-
88-
sampling_params = SamplingParams(temperature=0.4, min_p=0.2, max_tokens=64, logprobs=None)
89-
output = llm.generate(prompts, sampling_params)
90-
print(output)
107+
for sp in [greedy_sampling_params, regular_sampling_params, \
108+
topkp_sampling_params]:
109+
output = llm.generate(prompts, sp)
110+
for o in output:
111+
check_num_logprobs(o.outputs[0].logprobs, 4)

0 commit comments

Comments
 (0)