Skip to content

Commit ef6bcf2

Browse files
committed
Add first draft of ZXTransformer to contrib.
This is a custom transformer which uses ZX-calculus through the PyZX library to perform circuit optimisation. See issue #6585.
1 parent bc4cd6d commit ef6bcf2

File tree

5 files changed

+302
-0
lines changed

5 files changed

+302
-0
lines changed

cirq-core/cirq/contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from cirq.contrib.qcircuit import circuit_to_latex_using_qcircuit
2525
from cirq.contrib import json
2626
from cirq.contrib.circuitdag import CircuitDag, Unique
27+
from cirq.contrib.zxtransformer import zx_transformer

cirq-core/cirq/contrib/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ pylatex~=1.4
66
# quimb
77
quimb~=1.7
88
opt_einsum
9+
10+
# required for zxtransformer
11+
pyzx==0.8.0
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom transformer for Cirq which uses ZX-Calculus for circuit optimization, implemented using
16+
PyZX."""
17+
18+
from cirq.contrib.zxtransformer.zxtransformer import zx_transformer
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom transformer for Cirq which uses ZX-Calculus for circuit optimization, implemented
16+
using PyZX."""
17+
18+
import functools
19+
from typing import List, Callable, Optional, Union
20+
21+
import cirq
22+
from cirq import circuits
23+
24+
import pyzx as zx
25+
from pyzx.circuit import gates as zx_gates
26+
from fractions import Fraction
27+
28+
29+
@functools.cache
30+
def _cirq_to_pyzx():
31+
return {
32+
cirq.H: zx_gates.HAD,
33+
cirq.CZ: zx_gates.CZ,
34+
cirq.CNOT: zx_gates.CNOT,
35+
cirq.SWAP: zx_gates.SWAP,
36+
cirq.CCZ: zx_gates.CCZ,
37+
}
38+
39+
40+
def cirq_gate_to_zx_gate(
41+
cirq_gate: Optional[cirq.Gate], qubits: List[int]
42+
) -> Optional[zx_gates.Gate]:
43+
"""Convert a Cirq gate to a PyZX gate."""
44+
if cirq.Gate is None:
45+
return None
46+
47+
if isinstance(cirq_gate, (cirq.Rx, cirq.XPowGate)):
48+
return zx_gates.XPhase(*qubits, phase=Fraction(cirq_gate.exponent).limit_denominator())
49+
if isinstance(cirq_gate, (cirq.Ry, cirq.YPowGate)):
50+
return zx_gates.YPhase(*qubits, phase=Fraction(cirq_gate.exponent).limit_denominator())
51+
if isinstance(cirq_gate, (cirq.Rz, cirq.ZPowGate)):
52+
return zx_gates.ZPhase(*qubits, phase=Fraction(cirq_gate.exponent).limit_denominator())
53+
54+
# TODO: Deal with exponents other than nice ones.
55+
if (gate := _cirq_to_pyzx().get(cirq_gate, None)) is not None:
56+
return gate(*qubits)
57+
58+
return None
59+
60+
61+
cirq_gate_table = {
62+
'rx': cirq.XPowGate,
63+
'ry': cirq.YPowGate,
64+
'rz': cirq.ZPowGate,
65+
'h': cirq.HPowGate,
66+
'cx': cirq.CXPowGate,
67+
'cz': cirq.CZPowGate,
68+
'swap': cirq.SwapPowGate,
69+
'ccz': cirq.CCZPowGate,
70+
}
71+
72+
73+
def _cirq_to_circuits_and_ops(
74+
circuit: circuits.AbstractCircuit, qubits: List[cirq.Qid]
75+
) -> List[Union[zx.Circuit, cirq.Operation]]:
76+
"""Convert an AbstractCircuit to a list of PyZX Circuits and cirq.Operations. As much of the
77+
AbstractCircuit is converted to PyZX as possible, but some gates are not supported by PyZX and
78+
are left as cirq.Operations.
79+
80+
:param circuit: The AbstractCircuit to convert.
81+
:return: A list of PyZX Circuits and cirq.Operations corresponding to the AbstractCircuit.
82+
"""
83+
circuits_and_ops: List[Union[zx.Circuit, cirq.Operation]] = []
84+
qubit_to_index = {qubit: index for index, qubit in enumerate(qubits)}
85+
current_circuit: Optional[zx.Circuit] = None
86+
for moment in circuit:
87+
for op in moment:
88+
gate_qubits = [qubit_to_index[qarg] for qarg in op.qubits]
89+
gate = cirq_gate_to_zx_gate(op.gate, gate_qubits)
90+
if not gate:
91+
# Encountered an operation not supported by PyZX, so just store it.
92+
# Flush the current PyZX Circuit first if there is one.
93+
if current_circuit is not None:
94+
circuits_and_ops.append(current_circuit)
95+
current_circuit = None
96+
circuits_and_ops.append(op)
97+
continue
98+
99+
if current_circuit is None:
100+
current_circuit = zx.Circuit(len(qubits))
101+
current_circuit.add_gate(gate)
102+
103+
# Flush any remaining PyZX Circuit.
104+
if current_circuit is not None:
105+
circuits_and_ops.append(current_circuit)
106+
107+
return circuits_and_ops
108+
109+
110+
def _recover_circuit(
111+
circuits_and_ops: List[Union[zx.Circuit, cirq.Operation]], qubits: List[cirq.Qid]
112+
) -> circuits.Circuit:
113+
"""Recovers a cirq.Circuit from a list of PyZX Circuits and cirq.Operations.
114+
115+
:param circuits_and_ops: The list of (optimized) PyZX Circuits and cirq.Operations from which to
116+
recover the cirq.Circuit.
117+
:return: An optimized version of the original input circuit to ZXTransformer.
118+
:raises ValueError: If an unsupported gate has been encountered.
119+
"""
120+
cirq_circuit = circuits.Circuit()
121+
for circuit_or_op in circuits_and_ops:
122+
if isinstance(circuit_or_op, cirq.Operation):
123+
cirq_circuit.append(circuit_or_op)
124+
continue
125+
for gate in circuit_or_op.gates:
126+
gate_name = (
127+
gate.qasm_name
128+
if not (hasattr(gate, 'adjoint') and gate.adjoint)
129+
else gate.qasm_name_adjoint
130+
)
131+
gate_type = cirq_gate_table[gate_name]
132+
if gate_type is None:
133+
raise ValueError(f"Unsupported gate: {gate_name}.")
134+
qargs: List[cirq.Qid] = []
135+
for attr in ['ctrl1', 'ctrl2', 'control', 'target']:
136+
if hasattr(gate, attr):
137+
qargs.append(qubits[getattr(gate, attr)])
138+
params: List[float] = []
139+
if hasattr(gate, 'phase'):
140+
params = [float(gate.phase)]
141+
elif hasattr(gate, 'phases'):
142+
params = [float(phase) for phase in gate.phases]
143+
elif gate_name in ('h', 'cz', 'cx', 'swap', 'ccz'):
144+
params = [1.0]
145+
cirq_circuit.append(gate_type(exponent=params[0])(*qargs))
146+
return cirq_circuit
147+
148+
149+
def _optimize(c: zx.Circuit) -> zx.Circuit:
150+
g = c.to_graph()
151+
zx.simplify.full_reduce(g)
152+
return zx.extract.extract_circuit(g)
153+
154+
155+
@cirq.transformer
156+
def zx_transformer(
157+
circuit: circuits.AbstractCircuit,
158+
context: Optional[cirq.TransformerContext] = None,
159+
optimizer: Callable[[zx.Circuit], zx.Circuit] = _optimize,
160+
) -> circuits.Circuit:
161+
"""Perform circuit optimization using pyzx.
162+
163+
Args:
164+
circuit: 'cirq.Circuit' input circuit to transform.
165+
context: `cirq.TransformerContext` storing common configurable
166+
options for transformers.
167+
optimizer: The optimization routine to execute. Defaults to `pyzx.simplify.full_reduce` if
168+
not specified.
169+
170+
Returns:
171+
The modified circuit after optimization.
172+
"""
173+
qubits: List[cirq.Qid] = [*circuit.all_qubits()]
174+
175+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, qubits)
176+
if not circuits_and_ops:
177+
copied_circuit = circuit.unfreeze(copy=True)
178+
return copied_circuit
179+
180+
circuits_and_ops = [
181+
optimizer(circuit) if isinstance(circuit, zx.Circuit) else circuit
182+
for circuit in circuits_and_ops
183+
]
184+
185+
return _recover_circuit(circuits_and_ops, qubits)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Cirq ZX transformer."""
16+
17+
from typing import Optional, Callable
18+
19+
import cirq
20+
import pyzx as zx
21+
22+
from cirq.contrib.zxtransformer.zxtransformer import zx_transformer, _cirq_to_circuits_and_ops
23+
24+
25+
def _run_zxtransformer(
26+
qc: cirq.Circuit, optimizer: Optional[Callable[[zx.Circuit], zx.Circuit]] = None
27+
) -> None:
28+
zx_qc = zx_transformer(qc) if optimizer is None else zx_transformer(qc, optimizer=optimizer)
29+
qubit_map = {qid: qid for qid in qc.all_qubits()}
30+
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(qc, zx_qc, qubit_map)
31+
32+
33+
def test_basic_circuit() -> None:
34+
"""Test a basic circuit.
35+
36+
Taken from https://github.com/Quantomatic/pyzx/blob/master/circuits/Fast/mod5_4_before
37+
"""
38+
q = cirq.LineQubit.range(5)
39+
circuit = cirq.Circuit(
40+
cirq.X(q[4]),
41+
cirq.H(q[4]),
42+
cirq.CCZ(q[0], q[3], q[4]),
43+
cirq.CCZ(q[2], q[3], q[4]),
44+
cirq.H(q[4]),
45+
cirq.CX(q[3], q[4]),
46+
cirq.H(q[4]),
47+
cirq.CCZ(q[1], q[2], q[4]),
48+
cirq.H(q[4]),
49+
cirq.CX(q[2], q[4]),
50+
cirq.H(q[4]),
51+
cirq.CCZ(q[0], q[1], q[4]),
52+
cirq.H(q[4]),
53+
cirq.CX(q[1], q[4]),
54+
cirq.CX(q[0], q[4]),
55+
)
56+
57+
_run_zxtransformer(circuit)
58+
59+
60+
def test_custom_optimize() -> None:
61+
"""Test custom optimize method."""
62+
q = cirq.LineQubit.range(4)
63+
circuit = cirq.Circuit(
64+
cirq.H(q[0]),
65+
cirq.H(q[1]),
66+
cirq.H(q[2]),
67+
cirq.H(q[3]),
68+
cirq.CX(q[0], q[1]),
69+
cirq.CX(q[1], q[2]),
70+
cirq.CX(q[2], q[3]),
71+
cirq.CX(q[3], q[0]),
72+
)
73+
74+
def optimize(circ: zx.Circuit) -> zx.Circuit:
75+
# Any function that takes a zx.Circuit and returns a zx.Circuit will do.
76+
return circ.to_basic_gates()
77+
78+
_run_zxtransformer(circuit, optimize)
79+
80+
81+
def test_measurement() -> None:
82+
"""Test a circuit with a measurement."""
83+
q = cirq.NamedQubit("q")
84+
circuit = cirq.Circuit(cirq.H(q), cirq.measure(q, key='c'), cirq.H(q))
85+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, [*circuit.all_qubits()])
86+
assert len(circuits_and_ops) == 3
87+
assert circuits_and_ops[1] == cirq.measure(q, key='c')
88+
89+
90+
def test_conditional_gate() -> None:
91+
"""Test a circuit with a conditional gate."""
92+
q = cirq.NamedQubit("q")
93+
circuit = cirq.Circuit(cirq.X(q), cirq.H(q).with_classical_controls('c'), cirq.X(q))
94+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, [*circuit.all_qubits()])
95+
assert len(circuits_and_ops) == 3

0 commit comments

Comments
 (0)