Skip to content

Commit b5b3b74

Browse files
authored
ProcessorSampler: route run_batch to run_sweep (#6357)
1 parent 2fdb447 commit b5b3b74

File tree

2 files changed

+44
-41
lines changed

2 files changed

+44
-41
lines changed

cirq-google/cirq_google/engine/processor_sampler.py

-21
Original file line numberDiff line numberDiff line change
@@ -76,27 +76,6 @@ async def run_batch_async(
7676
params_list: Optional[Sequence[cirq.Sweepable]] = None,
7777
repetitions: Union[int, Sequence[int]] = 1,
7878
) -> Sequence[Sequence['cg.EngineResult']]:
79-
"""Runs the supplied circuits.
80-
81-
In order to gain a speedup from using this method instead of other run
82-
methods, the following conditions must be satisfied:
83-
1. All circuits must measure the same set of qubits.
84-
2. The number of circuit repetitions must be the same for all
85-
circuits. That is, the `repetitions` argument must be an integer,
86-
or else a list with identical values.
87-
"""
88-
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
89-
if len(set(repetitions)) == 1:
90-
# All repetitions are the same so batching can be done efficiently
91-
job = await self._processor.run_batch_async(
92-
programs=programs,
93-
params_list=params_list,
94-
repetitions=repetitions[0],
95-
run_name=self._run_name,
96-
device_config_name=self._device_config_name,
97-
)
98-
return await job.batched_results_async()
99-
# Varying number of repetitions so no speedup
10079
return cast(
10180
Sequence[Sequence['cg.EngineResult']],
10281
await super().run_batch_async(programs, params_list, repetitions),

cirq-google/cirq_google/engine/processor_sampler_test.py

+44-20
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,28 @@ def test_run_batch(run_name, device_config_name):
5454
circuit2 = cirq.Circuit(cirq.Y(a))
5555
params1 = [cirq.ParamResolver({'t': 1})]
5656
params2 = [cirq.ParamResolver({'t': 2})]
57-
circuits = [circuit1, circuit2]
58-
params_list = [params1, params2]
59-
sampler.run_batch(circuits, params_list, 5)
60-
processor.run_batch_async.assert_called_with(
61-
params_list=params_list,
62-
programs=circuits,
63-
repetitions=5,
64-
run_name=run_name,
65-
device_config_name=device_config_name,
66-
)
57+
58+
sampler.run_batch([circuit1, circuit2], [params1, params2], 5)
59+
60+
expected_calls = [
61+
mock.call(
62+
program=circuit1,
63+
params=params1,
64+
repetitions=5,
65+
run_name=run_name,
66+
device_config_name=device_config_name,
67+
),
68+
mock.call().results_async(),
69+
mock.call(
70+
program=circuit2,
71+
params=params2,
72+
repetitions=5,
73+
run_name=run_name,
74+
device_config_name=device_config_name,
75+
),
76+
mock.call().results_async(),
77+
]
78+
processor.run_sweep_async.assert_has_calls(expected_calls)
6779

6880

6981
@pytest.mark.parametrize(
@@ -79,16 +91,28 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
7991
circuit2 = cirq.Circuit(cirq.Y(a))
8092
params1 = [cirq.ParamResolver({'t': 1})]
8193
params2 = [cirq.ParamResolver({'t': 2})]
82-
circuits = [circuit1, circuit2]
83-
params_list = [params1, params2]
84-
sampler.run_batch(circuits, params_list, [5, 5])
85-
processor.run_batch_async.assert_called_with(
86-
params_list=params_list,
87-
programs=circuits,
88-
repetitions=5,
89-
run_name=run_name,
90-
device_config_name=device_config_name,
91-
)
94+
95+
sampler.run_batch([circuit1, circuit2], [params1, params2], [5, 5])
96+
97+
expected_calls = [
98+
mock.call(
99+
program=circuit1,
100+
params=params1,
101+
repetitions=5,
102+
run_name=run_name,
103+
device_config_name=device_config_name,
104+
),
105+
mock.call().results_async(),
106+
mock.call(
107+
program=circuit2,
108+
params=params2,
109+
repetitions=5,
110+
run_name=run_name,
111+
device_config_name=device_config_name,
112+
),
113+
mock.call().results_async(),
114+
]
115+
processor.run_sweep_async.assert_has_calls(expected_calls)
92116

93117

94118
def test_run_batch_bad_number_of_repetitions():

0 commit comments

Comments
 (0)