Skip to content

Commit 9b713dc

Browse files
maffoorht
authored andcommitted
Add async methods to AbstractEngine and AbstractJob (quantumlib#5555)
Review: @95-martin-orion
1 parent 3da63e0 commit 9b713dc

21 files changed

+222
-190
lines changed

cirq-core/cirq/work/sampler.py

+48-11
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414
"""Abstract base class for things sampling quantum circuits."""
1515

16-
import abc
1716
import collections
1817
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
1918

19+
import duet
2020
import pandas as pd
2121

22-
from cirq import ops, protocols, study
22+
from cirq import ops, protocols, study, value
2323
from cirq.work.observable_measurement import (
2424
measure_observables,
2525
RepetitionsStoppingCriteria,
@@ -31,7 +31,7 @@
3131
import cirq
3232

3333

34-
class Sampler(metaclass=abc.ABCMeta):
34+
class Sampler(metaclass=value.ABCMetaImplementAnyOneOf):
3535
"""Something capable of sampling quantum circuits. Simulator or hardware."""
3636

3737
def run(
@@ -177,7 +177,19 @@ def sample(
177177

178178
return pd.concat(results)
179179

180-
@abc.abstractmethod
180+
def _run_sweep_impl(
181+
self, program: 'cirq.AbstractCircuit', params: 'cirq.Sweepable', repetitions: int = 1
182+
) -> Sequence['cirq.Result']:
183+
"""Implements run_sweep using run_sweep_async"""
184+
return duet.run(self.run_sweep_async, program, params, repetitions)
185+
186+
async def _run_sweep_async_impl(
187+
self, program: 'cirq.AbstractCircuit', params: 'cirq.Sweepable', repetitions: int = 1
188+
) -> Sequence['cirq.Result']:
189+
"""Implements run_sweep_async using run_sweep"""
190+
return self.run_sweep(program, params=params, repetitions=repetitions)
191+
192+
@value.alternative(requires='run_sweep_async', implementation=_run_sweep_impl)
181193
def run_sweep(
182194
self, program: 'cirq.AbstractCircuit', params: 'cirq.Sweepable', repetitions: int = 1
183195
) -> Sequence['cirq.Result']:
@@ -200,6 +212,7 @@ def run_sweep(
200212
Result list for this run; one for each possible parameter resolver.
201213
"""
202214

215+
@value.alternative(requires='run_sweep', implementation=_run_sweep_async_impl)
203216
async def run_sweep_async(
204217
self, program: 'cirq.AbstractCircuit', params: 'cirq.Sweepable', repetitions: int = 1
205218
) -> Sequence['cirq.Result']:
@@ -217,13 +230,12 @@ async def run_sweep_async(
217230
Returns:
218231
Result list for this run; one for each possible parameter resolver.
219232
"""
220-
return self.run_sweep(program, params=params, repetitions=repetitions)
221233

222234
def run_batch(
223235
self,
224236
programs: Sequence['cirq.AbstractCircuit'],
225-
params_list: Optional[List['cirq.Sweepable']] = None,
226-
repetitions: Union[int, List[int]] = 1,
237+
params_list: Optional[Sequence['cirq.Sweepable']] = None,
238+
repetitions: Union[int, Sequence[int]] = 1,
227239
) -> Sequence[Sequence['cirq.Result']]:
228240
"""Runs the supplied circuits.
229241
@@ -263,6 +275,34 @@ def run_batch(
263275
ValueError: If length of `programs` is not equal to the length
264276
of `params_list` or the length of `repetitions`.
265277
"""
278+
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
279+
return [
280+
self.run_sweep(circuit, params=params, repetitions=repetitions)
281+
for circuit, params, repetitions in zip(programs, params_list, repetitions)
282+
]
283+
284+
async def run_batch_async(
285+
self,
286+
programs: Sequence['cirq.AbstractCircuit'],
287+
params_list: Optional[Sequence['cirq.Sweepable']] = None,
288+
repetitions: Union[int, Sequence[int]] = 1,
289+
) -> Sequence[Sequence['cirq.Result']]:
290+
"""Runs the supplied circuits.
291+
292+
This is an asynchronous version of `run_batch`; see full docs there.
293+
"""
294+
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
295+
return [
296+
await self.run_sweep_async(circuit, params=params, repetitions=repetitions)
297+
for circuit, params, repetitions in zip(programs, params_list, repetitions)
298+
]
299+
300+
def _normalize_batch_args(
301+
self,
302+
programs: Sequence['cirq.AbstractCircuit'],
303+
params_list: Optional[Sequence['cirq.Sweepable']] = None,
304+
repetitions: Union[int, Sequence[int]] = 1,
305+
) -> Tuple[Sequence['cirq.Sweepable'], Sequence[int]]:
266306
if params_list is None:
267307
params_list = [None] * len(programs)
268308
if len(programs) != len(params_list):
@@ -277,10 +317,7 @@ def run_batch(
277317
'len(programs) and len(repetitions) must match. '
278318
f'Got {len(programs)} and {len(repetitions)}.'
279319
)
280-
return [
281-
self.run_sweep(circuit, params=params, repetitions=repetitions)
282-
for circuit, params, repetitions in zip(programs, params_list, repetitions)
283-
]
320+
return params_list, repetitions
284321

285322
def sample_expectation_values(
286323
self,

cirq-core/cirq/work/sampler_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,42 @@ def run_sweep(self, program, params, repetitions: int = 1):
5959
await FailingSampler().run_sweep_async(cirq.Circuit(), repetitions=1, params=None)
6060

6161

62+
def test_run_sweep_impl():
63+
"""Test run_sweep implemented in terms of run_sweep_async."""
64+
65+
class AsyncSampler(cirq.Sampler):
66+
async def run_sweep_async(self, program, params, repetitions: int = 1):
67+
await duet.sleep(0.001)
68+
return cirq.Simulator().run_sweep(program, params, repetitions)
69+
70+
results = AsyncSampler().run_sweep(
71+
cirq.Circuit(cirq.measure(cirq.GridQubit(0, 0), key='m')),
72+
cirq.Linspace('foo', 0, 1, 10),
73+
repetitions=10,
74+
)
75+
assert len(results) == 10
76+
for result in results:
77+
np.testing.assert_equal(result.records['m'], np.zeros((10, 1, 1)))
78+
79+
80+
@duet.sync
81+
async def test_run_sweep_async_impl():
82+
"""Test run_sweep_async implemented in terms of run_sweep."""
83+
84+
class SyncSampler(cirq.Sampler):
85+
def run_sweep(self, program, params, repetitions: int = 1):
86+
return cirq.Simulator().run_sweep(program, params, repetitions)
87+
88+
results = await SyncSampler().run_sweep_async(
89+
cirq.Circuit(cirq.measure(cirq.GridQubit(0, 0), key='m')),
90+
cirq.Linspace('foo', 0, 1, 10),
91+
repetitions=10,
92+
)
93+
assert len(results) == 10
94+
for result in results:
95+
np.testing.assert_equal(result.records['m'], np.zeros((10, 1, 1)))
96+
97+
6298
def test_sampler_sample_multiple_params():
6399
a, b = cirq.LineQubit.range(2)
64100
s = sympy.Symbol('s')

cirq-core/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# functools.cached_property was introduced in python 3.8
44
backports.cached_property~=1.0.1; python_version < '3.8'
55

6-
duet~=0.2.6
6+
duet~=0.2.7
77
matplotlib~=3.0
88
networkx~=2.4
99
numpy~=1.16

cirq-google/cirq_google/engine/abstract_job.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import abc
1717
from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING
1818

19+
import duet
20+
1921
import cirq
2022
from cirq_google.cloud import quantum
2123
from cirq_google.engine.engine_result import EngineResult
@@ -162,25 +164,31 @@ def delete(self) -> Optional[bool]:
162164
"""Deletes the job and result, if any."""
163165

164166
@abc.abstractmethod
165-
def batched_results(self) -> Sequence[Sequence[EngineResult]]:
167+
async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]:
166168
"""Returns the job results, blocking until the job is complete.
167169
168170
This method is intended for batched jobs. Instead of flattening
169171
results into a single list, this will return a List[Result]
170172
for each circuit in the batch.
171173
"""
172174

175+
batched_results = duet.sync(batched_results_async)
176+
173177
@abc.abstractmethod
174-
def results(self) -> Sequence[EngineResult]:
178+
async def results_async(self) -> Sequence[EngineResult]:
175179
"""Returns the job results, blocking until the job is complete."""
176180

181+
results = duet.sync(results_async)
182+
177183
@abc.abstractmethod
178-
def calibration_results(self) -> Sequence['calibration_result.CalibrationResult']:
184+
async def calibration_results_async(self) -> Sequence['calibration_result.CalibrationResult']:
179185
"""Returns the results of a run_calibration() call.
180186
181187
This function will fail if any other type of results were returned.
182188
"""
183189

190+
calibration_results = duet.sync(calibration_results_async)
191+
184192
def __iter__(self) -> Iterator[cirq.Result]:
185193
yield from self.results()
186194

cirq-google/cirq_google/engine/abstract_job_test.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,13 @@ def cancel(self) -> None:
8282
def delete(self) -> None:
8383
pass
8484

85-
def batched_results(self):
85+
async def batched_results_async(self):
8686
pass
8787

88-
def results(self):
89-
return list(
90-
cirq.ResultDict(params={}, measurements={'a': np.asarray([t])}) for t in range(5)
91-
)
88+
async def results_async(self):
89+
return [cirq.ResultDict(params={}, measurements={'a': np.asarray([t])}) for t in range(5)]
9290

93-
def calibration_results(self):
91+
async def calibration_results_async(self):
9492
pass
9593

9694

cirq-google/cirq_google/engine/abstract_local_engine_test.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,13 @@ def health(self, *args, **kwargs):
5353
def list_calibrations(self, *args, **kwargs):
5454
pass
5555

56-
def run(self, *args, **kwargs):
56+
async def run_batch_async(self, *args, **kwargs):
5757
pass
5858

59-
def run_batch(self, *args, **kwargs):
59+
async def run_calibration_async(self, *args, **kwargs):
6060
pass
6161

62-
def run_calibration(self, *args, **kwargs):
63-
pass
64-
65-
def run_sweep(self, *args, **kwargs):
62+
async def run_sweep_async(self, *args, **kwargs):
6663
pass
6764

6865
def get_sampler(self, *args, **kwargs):

cirq-google/cirq_google/engine/abstract_local_job_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ def cancel(self) -> None:
4141
def delete(self) -> None:
4242
pass
4343

44-
def batched_results(self) -> Sequence[Sequence[EngineResult]]:
44+
async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]:
4545
return [] # coverage: ignore
4646

47-
def results(self) -> Sequence[EngineResult]:
47+
async def results_async(self) -> Sequence[EngineResult]:
4848
return [] # coverage: ignore
4949

50-
def calibration_results(self) -> Sequence[CalibrationResult]:
50+
async def calibration_results_async(self) -> Sequence[CalibrationResult]:
5151
return [] # coverage: ignore
5252

5353

cirq-google/cirq_google/engine/abstract_local_processor_test.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,13 @@ def health(self, *args, **kwargs):
5252
def list_calibrations(self, *args, **kwargs):
5353
pass
5454

55-
def run(self, *args, **kwargs):
55+
async def run_batch_async(self, *args, **kwargs):
5656
pass
5757

58-
def run_batch(self, *args, **kwargs):
58+
async def run_calibration_async(self, *args, **kwargs):
5959
pass
6060

61-
def run_calibration(self, *args, **kwargs):
62-
pass
63-
64-
def run_sweep(self, *args, **kwargs):
61+
async def run_sweep_async(self, *args, **kwargs):
6562
pass
6663

6764
def get_sampler(self, *args, **kwargs):

cirq-google/cirq_google/engine/abstract_processor.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import abc
2222
import datetime
23-
2423
from typing import Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
2524

26-
import cirq
25+
import duet
2726

27+
import cirq
2828
from cirq_google.api import v2
2929
from cirq_google.cloud import quantum
3030
from cirq_google.engine import calibration, util
@@ -53,7 +53,7 @@ class AbstractProcessor(abc.ABC):
5353
This is an abstract class. Inheritors should implement abstract methods.
5454
"""
5555

56-
def run(
56+
async def run_async(
5757
self,
5858
program: cirq.Circuit,
5959
program_id: Optional[str] = None,
@@ -88,9 +88,23 @@ def run(
8888
Returns:
8989
A single Result for this run.
9090
"""
91+
job = await self.run_sweep_async(
92+
program=program,
93+
program_id=program_id,
94+
job_id=job_id,
95+
params=[param_resolver or cirq.ParamResolver({})],
96+
repetitions=repetitions,
97+
program_description=program_description,
98+
program_labels=program_labels,
99+
job_description=job_description,
100+
job_labels=job_labels,
101+
)
102+
return job.results()[0]
103+
104+
run = duet.sync(run_async)
91105

92106
@abc.abstractmethod
93-
def run_sweep(
107+
async def run_sweep_async(
94108
self,
95109
program: cirq.AbstractCircuit,
96110
program_id: Optional[str] = None,
@@ -129,8 +143,10 @@ def run_sweep(
129143
`cirq.Result`, one for each parameter sweep.
130144
"""
131145

146+
run_sweep = duet.sync(run_sweep_async)
147+
132148
@abc.abstractmethod
133-
def run_batch(
149+
async def run_batch_async(
134150
self,
135151
programs: Sequence[cirq.AbstractCircuit],
136152
program_id: Optional[str] = None,
@@ -180,8 +196,10 @@ def run_batch(
180196
parameter sweep.
181197
"""
182198

199+
run_batch = duet.sync(run_batch_async)
200+
183201
@abc.abstractmethod
184-
def run_calibration(
202+
async def run_calibration_async(
185203
self,
186204
layers: List['cg.CalibrationLayer'],
187205
program_id: Optional[str] = None,
@@ -223,6 +241,8 @@ def run_calibration(
223241
calibration_results().
224242
"""
225243

244+
run_calibration = duet.sync(run_calibration_async)
245+
226246
@abc.abstractmethod
227247
def get_sampler(self) -> 'cg.ProcessorSampler':
228248
"""Returns a sampler backed by the processor."""

0 commit comments

Comments
 (0)