Skip to content

Commit b87d575

Browse files
authored
Modify run_batch_async (quantumlib#6387)
* Created using Colaboratory * call run_sweep_async in parallel * call circuits asynchronously * Revert "Created using Colaboratory" This reverts commit eb10318. * Revert "call run_sweep_async in parallel" This reverts commit f4e0d88. * revert colab * lint * use pmap * add test * lint * lint * nits
1 parent b99f633 commit b87d575

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

cirq/work/sampler.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,9 @@ async def run_batch_async(
294294
See docs for `cirq.Sampler.run_batch`.
295295
"""
296296
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
297-
return [
298-
await self.run_sweep_async(circuit, params=params, repetitions=repetitions)
299-
for circuit, params, repetitions in zip(programs, params_list, repetitions)
300-
]
297+
return await duet.pstarmap_async(
298+
self.run_sweep_async, zip(programs, params_list, repetitions)
299+
)
301300

302301
def _normalize_batch_args(
303302
self,

cirq/work/sampler_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,31 @@ def test_sampler_run_batch():
211211
assert np.array_equal(result.measurements['m'], np.array([[0], [0]], dtype='uint8'))
212212

213213

214+
@duet.sync
215+
async def test_run_batch_async_calls_run_sweep_asynchronously():
216+
"""Test run_batch_async calls run_sweep_async without waiting."""
217+
finished = []
218+
a = cirq.LineQubit(0)
219+
circuit1 = cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
220+
circuit2 = cirq.Circuit(cirq.Y(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
221+
params1 = cirq.Points('t', [0.3, 0.7])
222+
params2 = cirq.Points('t', [0.4, 0.6])
223+
params_list = [params1, params2]
224+
225+
class AsyncSampler(cirq.Sampler):
226+
async def run_sweep_async(self, program, params, repetitions: int = 1):
227+
if params == params1:
228+
await duet.sleep(0.001)
229+
230+
finished.append(params)
231+
232+
await AsyncSampler().run_batch_async(
233+
[circuit1, circuit2], params_list=params_list, repetitions=[1, 2]
234+
)
235+
236+
assert finished == list(reversed(params_list))
237+
238+
214239
def test_sampler_run_batch_default_params_and_repetitions():
215240
sampler = cirq.ZerosSampler()
216241
a = cirq.LineQubit(0)

0 commit comments

Comments
 (0)