@@ -211,6 +211,31 @@ def test_sampler_run_batch():
211
211
assert np .array_equal (result .measurements ['m' ], np .array ([[0 ], [0 ]], dtype = 'uint8' ))
212
212
213
213
214
+ @duet .sync
215
+ async def test_run_batch_async_calls_run_sweep_asynchronously ():
216
+ """Test run_batch_async calls run_sweep_async without waiting."""
217
+ finished = []
218
+ a = cirq .LineQubit (0 )
219
+ circuit1 = cirq .Circuit (cirq .X (a ) ** sympy .Symbol ('t' ), cirq .measure (a , key = 'm' ))
220
+ circuit2 = cirq .Circuit (cirq .Y (a ) ** sympy .Symbol ('t' ), cirq .measure (a , key = 'm' ))
221
+ params1 = cirq .Points ('t' , [0.3 , 0.7 ])
222
+ params2 = cirq .Points ('t' , [0.4 , 0.6 ])
223
+ params_list = [params1 , params2 ]
224
+
225
+ class AsyncSampler (cirq .Sampler ):
226
+ async def run_sweep_async (self , program , params , repetitions : int = 1 ):
227
+ if params == params1 :
228
+ await duet .sleep (0.001 )
229
+
230
+ finished .append (params )
231
+
232
+ await AsyncSampler ().run_batch_async (
233
+ [circuit1 , circuit2 ], params_list = params_list , repetitions = [1 , 2 ]
234
+ )
235
+
236
+ assert finished == list (reversed (params_list ))
237
+
238
+
214
239
def test_sampler_run_batch_default_params_and_repetitions ():
215
240
sampler = cirq .ZerosSampler ()
216
241
a = cirq .LineQubit (0 )
0 commit comments