Skip to content

Commit f2c6f3c

Browse files
Add an optional CompilationTargetGateset postprocessor to contract the circuit (#6433)
1 parent 5dd05bf commit f2c6f3c

File tree

4 files changed

+219
-19
lines changed

4 files changed

+219
-19
lines changed

cirq-core/cirq/transformers/optimize_for_target_gateset.py

+39-17
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Transformers to rewrite a circuit using gates from a given target gateset."""
1616

17-
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING
17+
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING, Union
1818

1919
from cirq import circuits
2020
from cirq.protocols import decompose_protocol as dp
@@ -102,19 +102,29 @@ def optimize_for_target_gateset(
102102
context: Optional['cirq.TransformerContext'] = None,
103103
gateset: Optional['cirq.CompilationTargetGateset'] = None,
104104
ignore_failures: bool = True,
105+
max_num_passes: Union[int, None] = 1,
105106
) -> 'cirq.Circuit':
106107
"""Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`.
107108
109+
Repeat max_num_passes times or when `max_num_passes=None` until no further changes can be done
108110
1. Run all `gateset.preprocess_transformers`
109111
2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`.
110112
3. Run all `gateset.postprocess_transformers`
111113
114+
Note:
115+
The optimizer is a heuristic and may not produce optimal results even with
116+
max_num_passes=None. The preprocessors and postprocessors of the gate set
117+
as well as their order yield different results.
118+
119+
112120
Args:
113121
circuit: Input circuit to transform. It will not be modified.
114122
context: `cirq.TransformerContext` storing common configurable options for transformers.
115123
gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`.
116124
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
117125
conversion failures raise a ValueError.
126+
max_num_passes: The maximum number of passes to do. A value of `None` means to keep
127+
iterating until no more changes happen to the number of moments or operations.
118128
119129
Returns:
120130
An equivalent circuit containing gates accepted by `gateset`.
@@ -126,20 +136,32 @@ def optimize_for_target_gateset(
126136
return _decompose_operations_to_target_gateset(
127137
circuit, context=context, ignore_failures=ignore_failures
128138
)
129-
130-
for transformer in gateset.preprocess_transformers:
131-
circuit = transformer(circuit, context=context)
132-
133-
circuit = _decompose_operations_to_target_gateset(
134-
circuit,
135-
context=context,
136-
gateset=gateset,
137-
decomposer=gateset.decompose_to_target_gateset,
138-
ignore_failures=ignore_failures,
139-
tags_to_decompose=(gateset._intermediate_result_tag,),
140-
)
141-
142-
for transformer in gateset.postprocess_transformers:
143-
circuit = transformer(circuit, context=context)
144-
139+
if isinstance(max_num_passes, int):
140+
_outerloop = lambda: range(max_num_passes)
141+
else:
142+
143+
def _outerloop():
144+
while True:
145+
yield 0
146+
147+
initial_num_moments, initial_num_ops = len(circuit), sum(1 for _ in circuit.all_operations())
148+
for _ in _outerloop():
149+
for transformer in gateset.preprocess_transformers:
150+
circuit = transformer(circuit, context=context)
151+
circuit = _decompose_operations_to_target_gateset(
152+
circuit,
153+
context=context,
154+
gateset=gateset,
155+
decomposer=gateset.decompose_to_target_gateset,
156+
ignore_failures=ignore_failures,
157+
tags_to_decompose=(gateset._intermediate_result_tag,),
158+
)
159+
for transformer in gateset.postprocess_transformers:
160+
circuit = transformer(circuit, context=context)
161+
162+
num_moments, num_ops = len(circuit), sum(1 for _ in circuit.all_operations())
163+
if (num_moments, num_ops) == (initial_num_moments, initial_num_ops):
164+
# Stop early. No further optimizations can be done.
165+
break
166+
initial_num_moments, initial_num_ops = num_moments, num_ops
145167
return circuit.unfreeze(copy=False)

cirq-core/cirq/transformers/optimize_for_target_gateset_test.py

+150
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Union
16+
1517
import cirq
1618
from cirq.protocols.decompose_protocol import DecomposeResult
1719
from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset
@@ -243,3 +245,151 @@ def test_optimize_for_target_gateset_deep():
243245
1: ───#2───────────────────────────────────────────────────────────────────────────
244246
''',
245247
)
248+
249+
250+
@pytest.mark.parametrize('max_num_passes', [2, None])
251+
def test_optimize_for_target_gateset_multiple_passes(max_num_passes: Union[int, None]):
252+
gateset = cirq.CZTargetGateset()
253+
254+
input_circuit = cirq.Circuit(
255+
[
256+
cirq.Moment(
257+
cirq.X(cirq.LineQubit(1)),
258+
cirq.X(cirq.LineQubit(2)),
259+
cirq.X(cirq.LineQubit(3)),
260+
cirq.X(cirq.LineQubit(6)),
261+
),
262+
cirq.Moment(
263+
cirq.H(cirq.LineQubit(0)),
264+
cirq.H(cirq.LineQubit(1)),
265+
cirq.H(cirq.LineQubit(2)),
266+
cirq.H(cirq.LineQubit(3)),
267+
cirq.H(cirq.LineQubit(4)),
268+
cirq.H(cirq.LineQubit(5)),
269+
cirq.H(cirq.LineQubit(6)),
270+
),
271+
cirq.Moment(
272+
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
273+
),
274+
cirq.Moment(
275+
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
276+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
277+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
278+
),
279+
cirq.Moment(
280+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
281+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
282+
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
283+
),
284+
]
285+
)
286+
desired_circuit = cirq.Circuit.from_moments(
287+
cirq.Moment(
288+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
289+
cirq.LineQubit(4)
290+
)
291+
),
292+
cirq.Moment(cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5))),
293+
cirq.Moment(
294+
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
295+
cirq.LineQubit(1)
296+
),
297+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
298+
cirq.LineQubit(0)
299+
),
300+
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
301+
cirq.LineQubit(3)
302+
),
303+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
304+
cirq.LineQubit(2)
305+
),
306+
),
307+
cirq.Moment(
308+
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
309+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
310+
),
311+
cirq.Moment(
312+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
313+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
314+
),
315+
cirq.Moment(
316+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
317+
cirq.LineQubit(6)
318+
)
319+
),
320+
cirq.Moment(cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5))),
321+
)
322+
got = cirq.optimize_for_target_gateset(
323+
input_circuit, gateset=gateset, max_num_passes=max_num_passes
324+
)
325+
cirq.testing.assert_same_circuits(got, desired_circuit)
326+
327+
328+
@pytest.mark.parametrize('max_num_passes', [2, None])
329+
def test_optimize_for_target_gateset_multiple_passes_dont_preserve_moment_structure(
330+
max_num_passes: Union[int, None]
331+
):
332+
gateset = cirq.CZTargetGateset(preserve_moment_structure=False)
333+
334+
input_circuit = cirq.Circuit(
335+
[
336+
cirq.Moment(
337+
cirq.X(cirq.LineQubit(1)),
338+
cirq.X(cirq.LineQubit(2)),
339+
cirq.X(cirq.LineQubit(3)),
340+
cirq.X(cirq.LineQubit(6)),
341+
),
342+
cirq.Moment(
343+
cirq.H(cirq.LineQubit(0)),
344+
cirq.H(cirq.LineQubit(1)),
345+
cirq.H(cirq.LineQubit(2)),
346+
cirq.H(cirq.LineQubit(3)),
347+
cirq.H(cirq.LineQubit(4)),
348+
cirq.H(cirq.LineQubit(5)),
349+
cirq.H(cirq.LineQubit(6)),
350+
),
351+
cirq.Moment(
352+
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
353+
),
354+
cirq.Moment(
355+
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
356+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
357+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
358+
),
359+
cirq.Moment(
360+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
361+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
362+
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
363+
),
364+
]
365+
)
366+
desired_circuit = cirq.Circuit(
367+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
368+
cirq.LineQubit(4)
369+
),
370+
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
371+
cirq.LineQubit(1)
372+
),
373+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
374+
cirq.LineQubit(2)
375+
),
376+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
377+
cirq.LineQubit(0)
378+
),
379+
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
380+
cirq.LineQubit(3)
381+
),
382+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
383+
cirq.LineQubit(6)
384+
),
385+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
386+
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
387+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
388+
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
389+
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
390+
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
391+
)
392+
got = cirq.optimize_for_target_gateset(
393+
input_circuit, gateset=gateset, max_num_passes=max_num_passes
394+
)
395+
cirq.testing.assert_same_circuits(got, desired_circuit)

cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Base class for creating custom target gatesets which can be used for compilation."""
1616

17-
from typing import Optional, List, Hashable, TYPE_CHECKING
17+
from typing import Optional, List, Hashable, TYPE_CHECKING, Union, Type
1818
import abc
1919

2020
from cirq import circuits, ops, protocols, transformers
@@ -80,6 +80,27 @@ class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta):
8080
which can transform any given circuit to contain gates accepted by this gateset.
8181
"""
8282

83+
def __init__(
84+
self,
85+
*gates: Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily'],
86+
name: Optional[str] = None,
87+
unroll_circuit_op: bool = True,
88+
preserve_moment_structure: bool = True,
89+
):
90+
"""Initializes CompilationTargetGateset.
91+
92+
Args:
93+
*gates: A list of `cirq.Gate` subclasses / `cirq.Gate` instances /
94+
`cirq.GateFamily` instances to initialize the Gateset.
95+
name: (Optional) Name for the Gateset. Useful for description.
96+
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
97+
validated by validating the underlying `cirq.Circuit`.
98+
preserve_moment_structure: Whether to preserve the moment structure of the
99+
circuit during compilation or not.
100+
"""
101+
super().__init__(*gates, name=name, unroll_circuit_op=unroll_circuit_op)
102+
self._preserve_moment_structure = preserve_moment_structure
103+
83104
@property
84105
@abc.abstractmethod
85106
def num_qubits(self) -> int:
@@ -140,11 +161,14 @@ def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
140161
@property
141162
def postprocess_transformers(self) -> List['cirq.TRANSFORMER']:
142163
"""List of transformers which should be run after decomposing individual operations."""
143-
return [
164+
processors: List['cirq.TRANSFORMER'] = [
144165
merge_single_qubit_gates.merge_single_qubit_moments_to_phxz,
145166
transformers.drop_negligible_operations,
146167
transformers.drop_empty_moments,
147168
]
169+
if not self._preserve_moment_structure:
170+
processors.append(transformers.stratified_circuit)
171+
return processors
148172

149173

150174
class TwoQubitCompilationTargetGateset(CompilationTargetGateset):

cirq-core/cirq/transformers/target_gatesets/cz_gateset.py

+4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
atol: float = 1e-8,
4949
allow_partial_czs: bool = False,
5050
additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (),
51+
preserve_moment_structure: bool = True,
5152
) -> None:
5253
"""Initializes CZTargetGateset
5354
@@ -57,6 +58,8 @@ def __init__(
5758
`cirq.CZ`, are part of this gateset.
5859
additional_gates: Sequence of additional gates / gate families which should also
5960
be "accepted" by this gateset. This is empty by default.
61+
preserve_moment_structure: Whether to preserve the moment structure of the
62+
circuit during compilation or not.
6063
"""
6164
super().__init__(
6265
ops.CZPowGate if allow_partial_czs else ops.CZ,
@@ -65,6 +68,7 @@ def __init__(
6568
ops.GlobalPhaseGate,
6669
*additional_gates,
6770
name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset',
71+
preserve_moment_structure=preserve_moment_structure,
6872
)
6973
self.additional_gates = tuple(
7074
g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates

0 commit comments

Comments
 (0)