Skip to content

Add async support in EngineClient, EngineSampler, etc. #5219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions cirq-google/cirq_google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import string
from typing import Dict, Iterable, List, Optional, Sequence, Set, TypeVar, Union, TYPE_CHECKING

import duet
import google.auth
from google.protobuf import any_pb2

Expand Down Expand Up @@ -493,7 +494,7 @@ def run_calibration(
)

@util.deprecated_gate_set_parameter
def create_program(
async def create_program_async(
self,
program: cirq.AbstractCircuit,
program_id: Optional[str] = None,
Expand Down Expand Up @@ -524,7 +525,7 @@ def create_program(
if not program_id:
program_id = _make_random_id('prog-')

new_program_id, new_program = self.context.client.create_program(
new_program_id, new_program = await self.context.client.create_program_async(
self.project_id,
program_id,
code=self.context._serialize_program(program, gate_set),
Expand All @@ -536,8 +537,10 @@ def create_program(
self.project_id, new_program_id, self.context, new_program
)

create_program = duet.sync(create_program_async)

@util.deprecated_gate_set_parameter
def create_batch_program(
async def create_batch_program_async(
self,
programs: Sequence[cirq.AbstractCircuit],
program_id: Optional[str] = None,
Expand Down Expand Up @@ -574,7 +577,7 @@ def create_batch_program(
for program in programs:
gate_set.serialize(program, msg=batch.programs.add())

new_program_id, new_program = self.context.client.create_program(
new_program_id, new_program = await self.context.client.create_program_async(
self.project_id,
program_id,
code=util.pack_any(batch),
Expand All @@ -586,8 +589,10 @@ def create_batch_program(
self.project_id, new_program_id, self.context, new_program, result_type=ResultType.Batch
)

create_batch_program = duet.sync(create_batch_program_async)

@util.deprecated_gate_set_parameter
def create_calibration_program(
async def create_calibration_program_async(
self,
layers: List['cirq_google.CalibrationLayer'],
program_id: Optional[str] = None,
Expand Down Expand Up @@ -632,7 +637,7 @@ def create_calibration_program(
arg_to_proto(layer.args[arg], out=new_layer.args[arg])
gate_set.serialize(layer.program, msg=new_layer.layer)

new_program_id, new_program = self.context.client.create_program(
new_program_id, new_program = await self.context.client.create_program_async(
self.project_id,
program_id,
code=util.pack_any(calibration),
Expand All @@ -648,6 +653,8 @@ def create_calibration_program(
result_type=ResultType.Calibration,
)

create_calibration_program = duet.sync(create_calibration_program_async)

def get_program(self, program_id: str) -> engine_program.EngineProgram:
"""Returns an EngineProgram for an existing Quantum Engine program.

Expand All @@ -659,7 +666,7 @@ def get_program(self, program_id: str) -> engine_program.EngineProgram:
"""
return engine_program.EngineProgram(self.project_id, program_id, self.context)

def list_programs(
async def list_programs_async(
self,
created_before: Optional[Union[datetime.datetime, datetime.date]] = None,
created_after: Optional[Union[datetime.datetime, datetime.date]] = None,
Expand All @@ -681,7 +688,7 @@ def list_programs(
"""

client = self.context.client
response = client.list_programs(
response = await client.list_programs_async(
self.project_id,
created_before=created_before,
created_after=created_after,
Expand All @@ -697,7 +704,9 @@ def list_programs(
for p in response
]

def list_jobs(
list_programs = duet.sync(list_programs_async)

async def list_jobs_async(
self,
created_before: Optional[Union[datetime.datetime, datetime.date]] = None,
created_after: Optional[Union[datetime.datetime, datetime.date]] = None,
Expand Down Expand Up @@ -730,7 +739,7 @@ def list_jobs(
`quantum.ExecutionStatus.State` enum for accepted values.
"""
client = self.context.client
response = client.list_jobs(
response = await client.list_jobs_async(
self.project_id,
None,
created_before=created_before,
Expand All @@ -749,7 +758,9 @@ def list_jobs(
for j in response
]

def list_processors(self) -> List[engine_processor.EngineProcessor]:
list_jobs = duet.sync(list_jobs_async)

async def list_processors_async(self) -> List[engine_processor.EngineProcessor]:
"""Returns a list of Processors that the user has visibility to in the
current Engine project. The names of these processors are used to
identify devices when scheduling jobs and gathering calibration metrics.
Expand All @@ -758,14 +769,16 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]:
A list of EngineProcessors to access status, device and calibration
information.
"""
response = self.context.client.list_processors(self.project_id)
response = await self.context.client.list_processors_async(self.project_id)
return [
engine_processor.EngineProcessor(
self.project_id, engine_client._ids_from_processor_name(p.name)[1], self.context, p
)
for p in response
]

list_processors = duet.sync(list_processors_async)

def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor:
"""Returns an EngineProcessor for a Quantum Engine processor.

Expand Down
73 changes: 38 additions & 35 deletions cirq-google/cirq_google/engine/engine_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.
"""A helper for jobs that have been created on the Quantum Engine."""
import datetime
import time

from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING

import duet
from google.protobuf import any_pb2

import cirq
Expand Down Expand Up @@ -107,20 +107,25 @@ def program(self) -> 'engine_program.EngineProgram':

return engine_program.EngineProgram(self.project_id, self.program_id, self.context)

async def _get_job_async(self, return_run_context: bool = False) -> quantum.QuantumJob:
return await self.context.client.get_job_async(
self.project_id, self.program_id, self.job_id, return_run_context
)

_get_job = duet.sync(_get_job_async)

def _inner_job(self) -> quantum.QuantumJob:
if self._job is None:
self._job = self.context.client.get_job(
self.project_id, self.program_id, self.job_id, False
)
self._job = self._get_job()
return self._job

def _refresh_job(self) -> quantum.QuantumJob:
async def _refresh_job_async(self) -> quantum.QuantumJob:
if self._job is None or self._job.execution_status.state not in TERMINAL_STATES:
self._job = self.context.client.get_job(
self.project_id, self.program_id, self.job_id, False
)
self._job = await self._get_job_async()
return self._job

_refresh_job = duet.sync(_refresh_job_async)

def create_time(self) -> 'datetime.datetime':
"""Returns when the job was created."""
return self._inner_job().create_time
Expand Down Expand Up @@ -224,10 +229,7 @@ def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]:
A tuple of the repetition count and list of sweeps.
"""
if self._job is None or self._job.run_context is None:
self._job = self.context.client.get_job(
self.project_id, self.program_id, self.job_id, True
)

self._job = self._get_job(return_run_context=True)
return _deserialize_run_context(self._job.run_context)

def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]':
Expand Down Expand Up @@ -260,42 +262,26 @@ def delete(self) -> None:
"""Deletes the job and result, if any."""
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)

def batched_results(self) -> Sequence[Sequence[EngineResult]]:
async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]:
"""Returns the job results, blocking until the job is complete.

This method is intended for batched jobs. Instead of flattening
results into a single list, this will return a Sequence[Result]
for each circuit in the batch.
"""
self.results()
await self.results_async()
if self._batched_results is None:
raise ValueError('batched_results called for a non-batch result.')
return self._batched_results

def _wait_for_result(self):
job = self._refresh_job()
total_seconds_waited = 0.0
timeout = self.context.timeout
while True:
if timeout and total_seconds_waited >= timeout:
break
if job.execution_status.state in TERMINAL_STATES:
break
time.sleep(0.5)
total_seconds_waited += 0.5
job = self._refresh_job()
_raise_on_failure(job)
response = self.context.client.get_job_results(
self.project_id, self.program_id, self.job_id
)
return response.result
batched_results = duet.sync(batched_results_async)

def results(self) -> Sequence[EngineResult]:
async def results_async(self) -> Sequence[EngineResult]:
"""Returns the job results, blocking until the job is complete."""
import cirq_google.engine.engine as engine_base

if self._results is None:
result = self._wait_for_result()
result = await self._await_result_async()
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
if (
result_type == 'cirq.google.api.v1.Result'
Expand All @@ -317,7 +303,22 @@ def results(self) -> Sequence[EngineResult]:
raise ValueError(f'invalid result proto version: {result_type}')
return self._results

def calibration_results(self) -> Sequence[CalibrationResult]:
results = duet.sync(results_async)

async def _await_result_async(self) -> quantum.QuantumResult:
async with duet.timeout_scope(self.context.timeout):
while True:
job = await self._refresh_job_async()
if job.execution_status.state in TERMINAL_STATES:
break
await duet.sleep(0.5)
_raise_on_failure(job)
response = await self.context.client.get_job_results_async(
self.project_id, self.program_id, self.job_id
)
return response.result

async def calibration_results_async(self) -> Sequence[CalibrationResult]:
"""Returns the results of a run_calibration() call.

This function will fail if any other type of results were returned
Expand All @@ -326,7 +327,7 @@ def calibration_results(self) -> Sequence[CalibrationResult]:
import cirq_google.engine.engine as engine_base

if self._calibration_results is None:
result = self._wait_for_result()
result = await self._await_result_async()
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
if result_type != 'cirq.google.api.v2.FocusedCalibrationResult':
raise ValueError(f'Did not find calibration results, instead found: {result_type}')
Expand All @@ -343,6 +344,8 @@ def calibration_results(self) -> Sequence[CalibrationResult]:
self._calibration_results = cal_results
return self._calibration_results

calibration_results = duet.sync(calibration_results_async)

def _get_job_results_v1(self, result: v1.program_pb2.Result) -> Sequence[EngineResult]:
# coverage: ignore
job_id = self.id()
Expand Down
Loading