@@ -62,29 +62,50 @@ def test_sampler_different(model_name: str):
     # tokens match.
     assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
 
+
 @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO: TPU will appear busy if we fan out test params here
+@pytest.mark.parametrize("n_prompts", [1])
 @pytest.mark.skipif(not current_platform.is_tpu(),
                     reason="This test needs a TPU")
-def test_logprobs(model_name: str):
+def test_logprobs(model_name: str, n_prompts: int):
"""
73
+ Request top logprobs with different sampling settings and check
74
+ that results contains the requested number, ordered ascendingly.
70
75
"""
76
+
+    def check_num_logprobs(logprobs, expected_num: int):
+        for step in logprobs:
+            prev_logp = 1.0
+            # order by rank
+            sorted_step = dict(
+                sorted(step.items(), key=lambda item: item[1].rank))
+
+            # Can contain the sampled token
+            assert len(step) == expected_num or len(step) == expected_num + 1
+            # Check results are ordered by prob value
+            for rankno, (tid, logp) in enumerate(sorted_step.items()):
+                assert logp.logprob <= prev_logp
+                prev_logp = logp.logprob
+                assert logp.rank == rankno + 1
+
     llm = LLM(model_name,
               enforce_eager=False,
               max_num_seqs=1,
-              max_model_len=512,
-              max_num_batched_tokens=512)
+              max_model_len=128,
+              max_num_batched_tokens=128)
     prompts = [
         "Write a short story about a robot that dreams for the first time."
-    ]
-    # Greedy sampling
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
-    output = llm.generate(prompts, sampling_params)
-    print(output)
+    ] * n_prompts
+    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,
+                                            logprobs=4)
+    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                             logprobs=4)
+    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                           logprobs=4, top_k=12, top_p=0.5)
 
-    sampling_params = SamplingParams(temperature=0.4, min_p=0.2, max_tokens=64, logprobs=4)
-    output = llm.generate(prompts, sampling_params)
-    print(output)
-
-    sampling_params = SamplingParams(temperature=0.4, min_p=0.2, max_tokens=64, logprobs=None)
-    output = llm.generate(prompts, sampling_params)
-    print(output)
+    for sp in [greedy_sampling_params, regular_sampling_params,
+               topkp_sampling_params]:
+        output = llm.generate(prompts, sp)
+        for o in output:
+            check_num_logprobs(o.outputs[0].logprobs, 4)
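For context, check_num_logprobs() relies on the shape of the logprobs that vLLM returns: one dict per generated token, mapping a token id to an entry that carries a logprob and a 1-based rank. Below is a minimal standalone sketch of that contract, not part of the diff; StubLogprob is a hypothetical stand-in for the real entry objects found on RequestOutput.outputs[0].logprobs.

    # Standalone sketch, not part of the diff above. StubLogprob is a
    # hypothetical stand-in; it only models the two fields the test reads.
    from dataclasses import dataclass

    @dataclass
    class StubLogprob:
        logprob: float  # log-probability of the candidate token at this step
        rank: int       # 1-based rank among the top candidates

    # One generated token -> one dict of {token_id: entry}. With logprobs=4,
    # each dict holds the top-4 candidates, plus possibly the sampled token.
    logprobs = [{
        101: StubLogprob(logprob=-0.11, rank=1),
        202: StubLogprob(logprob=-0.73, rank=2),
        303: StubLogprob(logprob=-1.30, rank=3),
        404: StubLogprob(logprob=-2.25, rank=4),
    }]

    # The invariant the test enforces: sorting a step by ascending rank must
    # yield non-increasing logprob values, with ranks running 1..len(step).
    for step in logprobs:
        prev_logp = 1.0
        for rankno, (tid, entry) in enumerate(
                sorted(step.items(), key=lambda kv: kv[1].rank)):
            assert entry.logprob <= prev_logp
            assert entry.rank == rankno + 1
            prev_logp = entry.logprob

The expected_num + 1 branch in the test allows for the sampled token itself, which, per the "Can contain the sampled token" comment, may appear in the dict in addition to the requested top candidates.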