Skip to content

Commit 001acd1

Browse files
authored
Add synchronize_terminal_measurements transformer to replace SynchronizeTerminalMeasurements (quantumlib#4911)
- Part of quantumlib#4722 - Follows the new Transformer API quantumlib#4483 - Supports no compile tags NoCompile Tag for optimizers quantumlib#4253 - Fixes quantumlib#4907
1 parent fbb9b71 commit 001acd1

6 files changed

+324
-3
lines changed

cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@
371371
single_qubit_matrix_to_phased_x_z,
372372
single_qubit_matrix_to_phxz,
373373
single_qubit_op_to_framed_phase_form,
374+
synchronize_terminal_measurements,
374375
TRANSFORMER,
375376
TransformerContext,
376377
TransformerLogger,

cirq/optimizers/synchronize_terminal_measurements.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
from typing import List, Set, Tuple, cast
1717
from cirq import circuits, ops, protocols
18+
from cirq._compat import deprecated_class
1819

1920

21+
@deprecated_class(deadline='v1.0', fix='Use cirq.synchronize_terminal_measurements instead.')
2022
class SynchronizeTerminalMeasurements:
2123
"""Move measurements to the end of the circuit.
2224

cirq/optimizers/synchronize_terminal_measurements_test.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616

1717

1818
def assert_optimizes(before, after, measure_only_moment=True):
19-
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
20-
opt(before)
21-
assert before == after
19+
with cirq.testing.assert_deprecated(
20+
"Use cirq.synchronize_terminal_measurements", deadline='v1.0'
21+
):
22+
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
23+
opt(before)
24+
assert before == after
2225

2326

2427
def test_no_move():

cirq/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444
from cirq.transformers.align import align_left, align_right
4545

46+
from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements
47+
4648
from cirq.transformers.transformer_api import (
4749
LogLevel,
4850
TRANSFORMER,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Transformer pass to move terminal measurements to the end of circuit."""
16+
17+
from typing import List, Optional, Set, Tuple, TYPE_CHECKING
18+
from cirq import protocols, ops
19+
from cirq.transformers import transformer_api
20+
21+
if TYPE_CHECKING:
22+
import cirq
23+
24+
25+
def find_terminal_measurements(
26+
circuit: 'cirq.AbstractCircuit',
27+
) -> List[Tuple[int, 'cirq.Operation']]:
28+
"""Finds all terminal measurements in the given circuit.
29+
30+
A measurement is terminal if there are no other operations acting on the measured qubits
31+
after the measurement operation occurs in the circuit.
32+
33+
Args:
34+
circuit: The circuit to find terminal measurements in.
35+
36+
Returns:
37+
List of terminal measurements, each specified as (moment_index, measurement_operation).
38+
"""
39+
40+
open_qubits: Set['cirq.Qid'] = set(circuit.all_qubits())
41+
seen_control_keys: Set['cirq.MeasurementKey'] = set()
42+
terminal_measurements: List[Tuple[int, 'cirq.Operation']] = []
43+
for i in range(len(circuit) - 1, -1, -1):
44+
moment = circuit[i]
45+
for q in open_qubits:
46+
op = moment.operation_at(q)
47+
seen_control_keys |= protocols.control_keys(op)
48+
if (
49+
op is not None
50+
and open_qubits.issuperset(op.qubits)
51+
and protocols.is_measurement(op)
52+
and not (seen_control_keys & protocols.measurement_key_objs(op))
53+
):
54+
terminal_measurements.append((i, op))
55+
open_qubits -= moment.qubits
56+
if not open_qubits:
57+
break
58+
return terminal_measurements
59+
60+
61+
@transformer_api.transformer
62+
def synchronize_terminal_measurements(
63+
circuit: 'cirq.AbstractCircuit',
64+
*,
65+
context: Optional['cirq.TransformerContext'] = None,
66+
after_other_operations: bool = True,
67+
) -> 'cirq.Circuit':
68+
"""Move measurements to the end of the circuit.
69+
70+
Move all measurements in a circuit to the final moment, if it can accommodate them (without
71+
overlapping with other operations). If `after_other_operations` is true, then a new moment will
72+
be added to the end of the circuit containing all the measurements that should be brought
73+
forward.
74+
75+
Args:
76+
circuit: Input circuit to transform.
77+
context: `cirq.TransformerContext` storing common configurable options for transformers.
78+
after_other_operations: Set by default. If the circuit's final moment contains
79+
non-measurement operations and this is set then a new empty moment is appended to
80+
the circuit before pushing measurements to the end.
81+
Returns:
82+
Copy of the transformed input circuit.
83+
"""
84+
if context is None:
85+
context = transformer_api.TransformerContext()
86+
terminal_measurements = [
87+
(i, op)
88+
for i, op in find_terminal_measurements(circuit)
89+
if set(op.tags).isdisjoint(context.ignore_tags)
90+
]
91+
ret = circuit.unfreeze(copy=True)
92+
if not terminal_measurements:
93+
return ret
94+
95+
ret.batch_remove(terminal_measurements)
96+
if ret[-1] and after_other_operations:
97+
ret.append(ops.Moment())
98+
ret[-1] = ret[-1].with_operations(op for _, op in terminal_measurements)
99+
return ret
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
17+
NO_COMPILE_TAG = "no_compile_tag"
18+
19+
20+
def assert_optimizes(before, after, measure_only_moment=True, with_context=False):
21+
transformed_circuit = (
22+
cirq.synchronize_terminal_measurements(before, after_other_operations=measure_only_moment)
23+
if not with_context
24+
else cirq.synchronize_terminal_measurements(
25+
before,
26+
context=cirq.TransformerContext(ignore_tags=(NO_COMPILE_TAG,)),
27+
after_other_operations=measure_only_moment,
28+
)
29+
)
30+
cirq.testing.assert_same_circuits(transformed_circuit, after)
31+
32+
33+
def test_no_move():
34+
q1 = cirq.NamedQubit('q1')
35+
before = cirq.Circuit([cirq.Moment([cirq.H(q1)])])
36+
after = before
37+
assert_optimizes(before=before, after=after)
38+
39+
40+
def test_simple_align():
41+
q1 = cirq.NamedQubit('q1')
42+
q2 = cirq.NamedQubit('q2')
43+
before = cirq.Circuit(
44+
[
45+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
46+
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.Z(q2)]),
47+
cirq.Moment([cirq.measure(q2)]),
48+
]
49+
)
50+
after = cirq.Circuit(
51+
[
52+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
53+
cirq.Moment([cirq.Z(q2)]),
54+
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.measure(q2)]),
55+
]
56+
)
57+
assert_optimizes(before=before, after=after)
58+
assert_optimizes(before=before, after=before, with_context=True)
59+
60+
61+
def test_simple_partial_align():
62+
q1 = cirq.NamedQubit('q1')
63+
q2 = cirq.NamedQubit('q2')
64+
before = cirq.Circuit(
65+
[
66+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
67+
cirq.Moment([cirq.Z(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
68+
]
69+
)
70+
after = cirq.Circuit(
71+
[
72+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
73+
cirq.Moment([cirq.Z(q1)]),
74+
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
75+
]
76+
)
77+
assert_optimizes(before=before, after=after)
78+
assert_optimizes(before=before, after=before, with_context=True)
79+
80+
81+
def test_slide_forward_one():
82+
q1 = cirq.NamedQubit('q1')
83+
q2 = cirq.NamedQubit('q2')
84+
q3 = cirq.NamedQubit('q3')
85+
before = cirq.Circuit(
86+
[
87+
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
88+
]
89+
)
90+
after = cirq.Circuit(
91+
[
92+
cirq.Moment([cirq.H(q1)]),
93+
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
94+
]
95+
)
96+
after_no_compile = cirq.Circuit(
97+
[
98+
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
99+
cirq.Moment([cirq.measure(q3)]),
100+
]
101+
)
102+
assert_optimizes(before=before, after=after)
103+
assert_optimizes(before=before, after=after_no_compile, with_context=True)
104+
105+
106+
def test_no_slide_forward_one():
107+
q1 = cirq.NamedQubit('q1')
108+
q2 = cirq.NamedQubit('q2')
109+
q3 = cirq.NamedQubit('q3')
110+
before = cirq.Circuit(
111+
[
112+
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
113+
]
114+
)
115+
after = cirq.Circuit(
116+
[
117+
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
118+
]
119+
)
120+
assert_optimizes(before=before, after=after, measure_only_moment=False)
121+
122+
123+
def test_blocked_shift_one():
124+
q1 = cirq.NamedQubit('q1')
125+
q2 = cirq.NamedQubit('q2')
126+
before = cirq.Circuit(
127+
[
128+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
129+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
130+
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
131+
]
132+
)
133+
after = cirq.Circuit(
134+
[
135+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
136+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
137+
cirq.Moment([cirq.H(q1)]),
138+
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
139+
]
140+
)
141+
assert_optimizes(before=before, after=after)
142+
assert_optimizes(before=before, after=before, with_context=True)
143+
144+
145+
def test_complex_move():
146+
q1 = cirq.NamedQubit('q1')
147+
q2 = cirq.NamedQubit('q2')
148+
q3 = cirq.NamedQubit('q3')
149+
before = cirq.Circuit(
150+
[
151+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
152+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
153+
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
154+
cirq.Moment([cirq.H(q3)]),
155+
cirq.Moment([cirq.X(q1), cirq.measure(q3).with_tags(NO_COMPILE_TAG)]),
156+
]
157+
)
158+
after = cirq.Circuit(
159+
[
160+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
161+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
162+
cirq.Moment([cirq.H(q1)]),
163+
cirq.Moment([cirq.H(q3)]),
164+
cirq.Moment([cirq.X(q1)]),
165+
cirq.Moment(
166+
[
167+
cirq.measure(q2).with_tags(NO_COMPILE_TAG),
168+
cirq.measure(q3).with_tags(NO_COMPILE_TAG),
169+
]
170+
),
171+
]
172+
)
173+
assert_optimizes(before=before, after=after)
174+
assert_optimizes(before=before, after=before, with_context=True)
175+
176+
177+
def test_complex_move_no_slide():
178+
q1 = cirq.NamedQubit('q1')
179+
q2 = cirq.NamedQubit('q2')
180+
q3 = cirq.NamedQubit('q3')
181+
before = cirq.Circuit(
182+
[
183+
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
184+
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
185+
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
186+
cirq.Moment([cirq.H(q3)]),
187+
cirq.Moment([cirq.X(q1), cirq.measure(q3)]),
188+
]
189+
)
190+
after = cirq.Circuit(
191+
[
192+
cirq.Moment(cirq.H(q1), cirq.H(q2)),
193+
cirq.Moment(cirq.measure(q1), cirq.Z(q2)),
194+
cirq.Moment(cirq.H(q1)),
195+
cirq.Moment(cirq.H(q3)),
196+
cirq.Moment(cirq.X(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)),
197+
]
198+
)
199+
assert_optimizes(before=before, after=after, measure_only_moment=False)
200+
assert_optimizes(before=before, after=before, measure_only_moment=False, with_context=True)
201+
202+
203+
def test_multi_qubit():
204+
q0, q1 = cirq.LineQubit.range(2)
205+
circuit = cirq.Circuit(cirq.measure(q0, q1, key='m'), cirq.H(q1))
206+
assert_optimizes(before=circuit, after=circuit)
207+
208+
209+
def test_classically_controlled_op():
210+
q0, q1 = cirq.LineQubit.range(2)
211+
circuit = cirq.Circuit(
212+
cirq.H(q0), cirq.measure(q0, key='m'), cirq.X(q1).with_classical_controls('m')
213+
)
214+
assert_optimizes(before=circuit, after=circuit)

0 commit comments

Comments
 (0)