Skip to content

Commit 87f77be

Browse files
authored
Create separate unary and streaming RPC tests in engine_test.py (#6311)
1 parent 4dc36d5 commit 87f77be

File tree

1 file changed

+228
-15
lines changed

1 file changed

+228
-15
lines changed

cirq-google/cirq_google/engine/engine_test.py

+228-15
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,50 @@ def setup_run_circuit_with_result_(client, result):
355355

356356

357357
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
358-
def test_run_circuit(client):
358+
def test_run_circuit_with_unary_rpcs(client):
359359
setup_run_circuit_with_result_(client, _A_RESULT)
360360

361-
engine = cg.Engine(project_id='proj', service_args={'client_info': 1})
361+
engine = cg.Engine(
362+
project_id='proj',
363+
context=EngineContext(service_args={'client_info': 1}, enable_streaming=False),
364+
)
365+
result = engine.run(
366+
program=_CIRCUIT, program_id='prog', job_id='job-id', processor_ids=['mysim']
367+
)
368+
369+
assert result.repetitions == 1
370+
assert result.params.param_dict == {'a': 1}
371+
assert result.measurements == {'q': np.array([[0]], dtype='uint8')}
372+
client.assert_called_with(service_args={'client_info': 1}, verbose=None)
373+
client().create_program_async.assert_called_once()
374+
client().create_job_async.assert_called_once_with(
375+
project_id='proj',
376+
program_id='prog',
377+
job_id='job-id',
378+
processor_ids=['mysim'],
379+
run_context=util.pack_any(
380+
v2.run_context_pb2.RunContext(
381+
parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)]
382+
)
383+
),
384+
description=None,
385+
labels=None,
386+
processor_id='',
387+
run_name='',
388+
device_config_name='',
389+
)
390+
client().get_job_async.assert_called_once_with('proj', 'prog', 'job-id', False)
391+
client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id')
392+
393+
394+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
395+
def test_run_circuit_with_stream_rpcs(client):
396+
setup_run_circuit_with_result_(client, _A_RESULT)
397+
398+
engine = cg.Engine(
399+
project_id='proj',
400+
context=EngineContext(service_args={'client_info': 1}, enable_streaming=True),
401+
)
362402
result = engine.run(
363403
program=_CIRCUIT, program_id='prog', job_id='job-id', processor_ids=['mysim']
364404
)
@@ -399,7 +439,37 @@ def test_unsupported_program_type():
399439

400440

401441
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
402-
def test_run_circuit_failed(client):
442+
def test_run_circuit_failed_with_unary_rpcs(client):
443+
client().create_program_async.return_value = (
444+
'prog',
445+
quantum.QuantumProgram(name='projects/proj/programs/prog'),
446+
)
447+
client().create_job_async.return_value = (
448+
'job-id',
449+
quantum.QuantumJob(
450+
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
451+
),
452+
)
453+
client().get_job_async.return_value = quantum.QuantumJob(
454+
name='projects/proj/programs/prog/jobs/job-id',
455+
execution_status={
456+
'state': 'FAILURE',
457+
'processor_name': 'myqc',
458+
'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'},
459+
},
460+
)
461+
462+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
463+
with pytest.raises(
464+
RuntimeError,
465+
match='Job projects/proj/programs/prog/jobs/job-id on processor'
466+
' myqc failed. SYSTEM_ERROR: Not good',
467+
):
468+
engine.run(program=_CIRCUIT)
469+
470+
471+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
472+
def test_run_circuit_failed_with_stream_rpcs(client):
403473
failed_job = quantum.QuantumJob(
404474
name='projects/proj/programs/prog/jobs/job-id',
405475
execution_status={
@@ -412,7 +482,7 @@ def test_run_circuit_failed(client):
412482
stream_future.try_set_result(failed_job)
413483
client().run_job_over_stream.return_value = stream_future
414484

415-
engine = cg.Engine(project_id='proj')
485+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
416486
with pytest.raises(
417487
RuntimeError,
418488
match='Job projects/proj/programs/prog/jobs/job-id on processor'
@@ -422,7 +492,36 @@ def test_run_circuit_failed(client):
422492

423493

424494
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
425-
def test_run_circuit_failed_missing_processor_name(client):
495+
def test_run_circuit_failed_missing_processor_name_with_unary_rpcs(client):
496+
client().create_program_async.return_value = (
497+
'prog',
498+
quantum.QuantumProgram(name='projects/proj/programs/prog'),
499+
)
500+
client().create_job_async.return_value = (
501+
'job-id',
502+
quantum.QuantumJob(
503+
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
504+
),
505+
)
506+
client().get_job_async.return_value = quantum.QuantumJob(
507+
name='projects/proj/programs/prog/jobs/job-id',
508+
execution_status={
509+
'state': 'FAILURE',
510+
'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'},
511+
},
512+
)
513+
514+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
515+
with pytest.raises(
516+
RuntimeError,
517+
match='Job projects/proj/programs/prog/jobs/job-id on processor'
518+
' UNKNOWN failed. SYSTEM_ERROR: Not good',
519+
):
520+
engine.run(program=_CIRCUIT)
521+
522+
523+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
524+
def test_run_circuit_failed_missing_processor_name_with_stream_rpcs(client):
426525
failed_job = quantum.QuantumJob(
427526
name='projects/proj/programs/prog/jobs/job-id',
428527
execution_status={
@@ -434,7 +533,7 @@ def test_run_circuit_failed_missing_processor_name(client):
434533
stream_future.try_set_result(failed_job)
435534
client().run_job_over_stream.return_value = stream_future
436535

437-
engine = cg.Engine(project_id='proj')
536+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
438537
with pytest.raises(
439538
RuntimeError,
440539
match='Job projects/proj/programs/prog/jobs/job-id on processor'
@@ -444,26 +543,78 @@ def test_run_circuit_failed_missing_processor_name(client):
444543

445544

446545
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
447-
def test_run_circuit_cancelled(client):
546+
def test_run_circuit_cancelled_with_unary_rpcs(client):
547+
client().create_program_async.return_value = (
548+
'prog',
549+
quantum.QuantumProgram(name='projects/proj/programs/prog'),
550+
)
551+
client().create_job_async.return_value = (
552+
'job-id',
553+
quantum.QuantumJob(
554+
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
555+
),
556+
)
557+
client().get_job_async.return_value = quantum.QuantumJob(
558+
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'}
559+
)
560+
561+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
562+
with pytest.raises(
563+
RuntimeError, match='Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
564+
):
565+
engine.run(program=_CIRCUIT)
566+
567+
568+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
569+
def test_run_circuit_cancelled_with_stream_rpcs(client):
448570
canceled_job = quantum.QuantumJob(
449571
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'}
450572
)
451573
stream_future = duet.AwaitableFuture()
452574
stream_future.try_set_result(canceled_job)
453575
client().run_job_over_stream.return_value = stream_future
454576

455-
engine = cg.Engine(project_id='proj')
577+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
456578
with pytest.raises(
457579
RuntimeError, match='Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
458580
):
459581
engine.run(program=_CIRCUIT)
460582

461583

462584
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
463-
def test_run_sweep_params(client):
585+
def test_run_sweep_params_with_unary_rpcs(client):
586+
setup_run_circuit_with_result_(client, _RESULTS)
587+
588+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
589+
job = engine.run_sweep(
590+
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
591+
)
592+
results = job.results()
593+
assert len(results) == 2
594+
for i, v in enumerate([1, 2]):
595+
assert results[i].repetitions == 1
596+
assert results[i].params.param_dict == {'a': v}
597+
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
598+
599+
client().create_program_async.assert_called_once()
600+
client().create_job_async.assert_called_once()
601+
602+
run_context = v2.run_context_pb2.RunContext()
603+
client().create_job_async.call_args[1]['run_context'].Unpack(run_context)
604+
sweeps = run_context.parameter_sweeps
605+
assert len(sweeps) == 2
606+
for i, v in enumerate([1.0, 2.0]):
607+
assert sweeps[i].repetitions == 1
608+
assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v]
609+
client().get_job_async.assert_called_once()
610+
client().get_job_results_async.assert_called_once()
611+
612+
613+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
614+
def test_run_sweep_params_with_stream_rpcs(client):
464615
setup_run_circuit_with_result_(client, _RESULTS)
465616

466-
engine = cg.Engine(project_id='proj')
617+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
467618
job = engine.run_sweep(
468619
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
469620
)
@@ -486,7 +637,12 @@ def test_run_sweep_params(client):
486637

487638

488639
def test_run_sweep_with_multiple_processor_ids():
489-
engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2)
640+
engine = cg.Engine(
641+
project_id='proj',
642+
context=EngineContext(
643+
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=True
644+
),
645+
)
490646
with pytest.raises(ValueError, match='multiple processors is no longer supported'):
491647
_ = engine.run_sweep(
492648
program=_CIRCUIT,
@@ -527,10 +683,44 @@ def test_run_multiple_times(client):
527683

528684

529685
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
530-
def test_run_sweep_v2(client):
686+
def test_run_sweep_v2_with_unary_rpcs(client):
531687
setup_run_circuit_with_result_(client, _RESULTS_V2)
532688

533-
engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2)
689+
engine = cg.Engine(
690+
project_id='proj',
691+
context=EngineContext(
692+
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=False
693+
),
694+
)
695+
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
696+
results = job.results()
697+
assert len(results) == 2
698+
for i, v in enumerate([1, 2]):
699+
assert results[i].repetitions == 1
700+
assert results[i].params.param_dict == {'a': v}
701+
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
702+
client().create_program_async.assert_called_once()
703+
client().create_job_async.assert_called_once()
704+
run_context = v2.run_context_pb2.RunContext()
705+
client().create_job_async.call_args[1]['run_context'].Unpack(run_context)
706+
sweeps = run_context.parameter_sweeps
707+
assert len(sweeps) == 1
708+
assert sweeps[0].repetitions == 1
709+
assert sweeps[0].sweep.single_sweep.points.points == [1, 2]
710+
client().get_job_async.assert_called_once()
711+
client().get_job_results_async.assert_called_once()
712+
713+
714+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
715+
def test_run_sweep_v2_with_stream_rpcs(client):
716+
setup_run_circuit_with_result_(client, _RESULTS_V2)
717+
718+
engine = cg.Engine(
719+
project_id='proj',
720+
context=EngineContext(
721+
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=True
722+
),
723+
)
534724
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
535725
results = job.results()
536726
assert len(results) == 2
@@ -772,10 +962,33 @@ def test_get_processor():
772962

773963

774964
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
775-
def test_sampler(client):
965+
def test_sampler_with_unary_rpcs(client):
966+
setup_run_circuit_with_result_(client, _RESULTS)
967+
968+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
969+
sampler = engine.get_sampler(processor_id='tmp')
970+
results = sampler.run_sweep(
971+
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
972+
)
973+
assert len(results) == 2
974+
for i, v in enumerate([1, 2]):
975+
assert results[i].repetitions == 1
976+
assert results[i].params.param_dict == {'a': v}
977+
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
978+
assert client().create_program_async.call_args[0][0] == 'proj'
979+
980+
with cirq.testing.assert_deprecated('sampler', deadline='1.0'):
981+
_ = engine.sampler(processor_id='tmp')
982+
983+
with pytest.raises(ValueError, match='list of processors'):
984+
_ = engine.get_sampler(['test1', 'test2'])
985+
986+
987+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
988+
def test_sampler_with_stream_rpcs(client):
776989
setup_run_circuit_with_result_(client, _RESULTS)
777990

778-
engine = cg.Engine(project_id='proj')
991+
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
779992
sampler = engine.get_sampler(processor_id='tmp')
780993
results = sampler.run_sweep(
781994
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]

0 commit comments

Comments
 (0)