@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
129
129
check_answers (indices , answer , test_texts )
130
130
131
131
132
- def prep_prompts (batch_size : int ):
132
+ def prep_prompts (batch_size : int , ln_range : tuple [ int , int ] = ( 800 , 1100 ) ):
133
133
"""
134
134
Generate prompts which a bunch of assignments,
135
135
then asking for the value of one of them.
136
136
The prompt is just under 10k tokens; sliding window is 4k
137
137
so the answer is outside sliding window, but should still be correct.
138
+
139
+ Args:
140
+ batch_size: number of prompts to generate
141
+ ln_range: an argument to control the length of the prompt
138
142
"""
139
143
prompts : list [str ] = []
140
144
answer : list [int ] = []
@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
145
149
indices .append (idx )
146
150
prompt = "```python\n # We set a number of variables, " + \
147
151
f"x{ idx } will be important later\n "
148
- ln = random .randint (800 , 1100 )
152
+ ln = random .randint (* ln_range )
149
153
for k in range (30 , ln ):
150
154
v = random .randint (10 , 99 )
151
155
if k == idx :
@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
157
161
return prompts , answer , indices
158
162
159
163
160
- def check_answers (indices : list [int ], answer : list [int ], outputs : list [str ]):
164
+ def check_answers (indices : list [int ],
165
+ answer : list [int ],
166
+ outputs : list [str ],
167
+ accept_rate : float = 0.7 ):
161
168
answer2 = [int (text [0 :2 ].strip ()) for text in outputs ]
162
169
print (list (zip (indices , zip (answer , answer2 ))))
163
170
numok = 0
@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
166
173
numok += 1
167
174
frac_ok = numok / len (answer )
168
175
print (f"Num OK: { numok } /{ len (answer )} { frac_ok } " )
169
- assert frac_ok > 0.7
176
+ assert frac_ok >= accept_rate
170
177
171
178
172
179
def check_window (prompts : list [str ]):
0 commit comments