Skip to content

Commit 6ec1f36

Browse files
dstrain115rht
authored andcommitted
Update Neutral Atoms to Transformers (quantumlib#5311)
* Update Neutral Atoms to Transformers Create NeutralAtomGateset that works as a compilation target gateset. Point ConvertToNeutralAtomGateset towards optimize_for_target_gateset instead.
1 parent 8d0616c commit 6ec1f36

8 files changed

+240
-65
lines changed

cirq-core/cirq/neutral_atoms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@
2121
is_native_neutral_atom_gate,
2222
is_native_neutral_atom_op,
2323
)
24+
25+
from cirq.neutral_atoms.neutral_atom_gateset import NeutralAtomGateset

cirq-core/cirq/neutral_atoms/convert_to_neutral_atom_gates.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
# limitations under the License.
1414
from typing import List, Optional, TYPE_CHECKING
1515

16-
from cirq import ops, protocols
16+
from cirq import ops, protocols, transformers
17+
from cirq._compat import deprecated_class
1718
from cirq.circuits.optimization_pass import PointOptimizationSummary, PointOptimizer
18-
from cirq.neutral_atoms import neutral_atom_devices
19-
from cirq import transformers
19+
from cirq.neutral_atoms.neutral_atom_gateset import NeutralAtomGateset
2020

2121
if TYPE_CHECKING:
2222
import cirq
2323

2424

25+
@deprecated_class(
26+
deadline='v0.16',
27+
fix='Use cirq.optimize_for_target_gateset(circuit, gateset=NeutralAtomGateset()).',
28+
)
2529
class ConvertToNeutralAtomGates(PointOptimizer):
2630
"""Attempts to convert gates into native Atom gates.
2731
@@ -48,7 +52,7 @@ def __init__(self, ignore_failures=False) -> None:
4852
"""
4953
super().__init__()
5054
self.ignore_failures = ignore_failures
51-
self.gateset = neutral_atom_devices.neutral_atom_gateset()
55+
self.gateset = NeutralAtomGateset()
5256

5357
def _convert_one(self, op: ops.Operation) -> ops.OP_TREE:
5458
# Known matrix?
@@ -91,8 +95,8 @@ def optimization_at(
9195

9296

9397
def is_native_neutral_atom_op(operation: ops.Operation) -> bool:
94-
return operation in neutral_atom_devices.neutral_atom_gateset()
98+
return operation in NeutralAtomGateset()
9599

96100

97101
def is_native_neutral_atom_gate(gate: ops.Gate) -> bool:
98-
return gate in neutral_atom_devices.neutral_atom_gateset()
102+
return gate in NeutralAtomGateset()

cirq-core/cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,37 +16,58 @@
1616
import pytest
1717

1818
import cirq
19-
from cirq import ops
19+
20+
21+
Q = cirq.LineQubit.range(3)
22+
23+
24+
@pytest.mark.parametrize(
25+
'expected',
26+
(
27+
cirq.Circuit(cirq.X.on(Q[0])),
28+
cirq.Circuit(cirq.Y.on(Q[0])),
29+
cirq.Circuit(cirq.ParallelGate(cirq.X, 3).on(*Q)),
30+
cirq.Circuit(cirq.CNOT.on(Q[0], Q[1])),
31+
),
32+
)
33+
def test_gates_preserved(expected: cirq.Circuit):
34+
actual = cirq.optimize_for_target_gateset(
35+
expected, gateset=cirq.neutral_atoms.NeutralAtomGateset()
36+
)
37+
assert actual == expected
2038

2139

2240
def test_coverage():
23-
q = cirq.LineQubit.range(3)
24-
g = cirq.testing.ThreeQubitGate()
25-
26-
class FakeOperation(ops.Operation):
27-
def __init__(self, gate, qubits):
28-
self._gate = gate
29-
self._qubits = qubits
30-
31-
@property
32-
def qubits(self):
33-
return self._qubits
34-
35-
def with_qubits(self, *new_qubits):
36-
return FakeOperation(self._gate, new_qubits)
37-
38-
op = FakeOperation(g, q).with_qubits(*q)
39-
circuit_ops = [cirq.Y(q[0]), cirq.ParallelGate(cirq.X, 3).on(*q)]
40-
c = cirq.Circuit(circuit_ops)
41-
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
42-
assert c == cirq.Circuit(circuit_ops)
43-
assert cirq.neutral_atoms.ConvertToNeutralAtomGates().convert(cirq.X.on(q[0])) == [
44-
cirq.X.on(q[0])
45-
]
46-
with pytest.raises(TypeError, match="Don't know how to work with"):
47-
cirq.neutral_atoms.ConvertToNeutralAtomGates().convert(op)
48-
assert not cirq.neutral_atoms.is_native_neutral_atom_op(op)
49-
assert not cirq.neutral_atoms.is_native_neutral_atom_gate(g)
41+
with cirq.testing.assert_deprecated(
42+
"Use cirq.optimize_for_target_gateset", deadline='v0.16', count=5
43+
):
44+
q = cirq.LineQubit.range(3)
45+
g = cirq.testing.ThreeQubitGate()
46+
47+
class FakeOperation(cirq.Operation):
48+
def __init__(self, gate, qubits):
49+
self._gate = gate
50+
self._qubits = qubits
51+
52+
@property
53+
def qubits(self):
54+
return self._qubits
55+
56+
def with_qubits(self, *new_qubits):
57+
return FakeOperation(self._gate, new_qubits)
58+
59+
op = FakeOperation(g, q).with_qubits(*q)
60+
circuit_ops = [cirq.Y(q[0]), cirq.ParallelGate(cirq.X, 3).on(*q)]
61+
c = cirq.Circuit(circuit_ops)
62+
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
63+
assert c == cirq.Circuit(circuit_ops)
64+
assert cirq.neutral_atoms.ConvertToNeutralAtomGates().convert(cirq.X.on(q[0])) == [
65+
cirq.X.on(q[0])
66+
]
67+
with pytest.raises(TypeError, match="Don't know how to work with"):
68+
cirq.neutral_atoms.ConvertToNeutralAtomGates().convert(op)
69+
assert not cirq.neutral_atoms.is_native_neutral_atom_op(op)
70+
assert not cirq.neutral_atoms.is_native_neutral_atom_gate(g)
5071

5172

5273
def test_avoids_decompose_fallback_when_matrix_available_single_qubit():
@@ -60,8 +81,13 @@ def _decompose_(self, qubits):
6081

6182
q = cirq.GridQubit(0, 0)
6283
c = cirq.Circuit(OtherX().on(q), OtherOtherX().on(q))
63-
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
64-
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───')
84+
converted = cirq.optimize_for_target_gateset(c, gateset=cirq.neutral_atoms.NeutralAtomGateset())
85+
cirq.testing.assert_has_diagram(converted, '(0, 0): ───PhX(1)───PhX(1)───')
86+
with cirq.testing.assert_deprecated(
87+
"Use cirq.optimize_for_target_gateset", deadline='v0.16', count=2
88+
):
89+
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
90+
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───')
6591

6692

6793
def test_avoids_decompose_fallback_when_matrix_available_two_qubit():
@@ -76,12 +102,15 @@ def _decompose_(self, qubits):
76102
q00 = cirq.GridQubit(0, 0)
77103
q01 = cirq.GridQubit(0, 1)
78104
c = cirq.Circuit(OtherCZ().on(q00, q01), OtherOtherCZ().on(q00, q01))
79-
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
80-
cirq.testing.assert_has_diagram(
81-
c,
82-
"""
105+
expected_diagram = """
83106
(0, 0): ───@───@───
84107
│ │
85108
(0, 1): ───@───@───
86-
""",
87-
)
109+
"""
110+
converted = cirq.optimize_for_target_gateset(c, gateset=cirq.neutral_atoms.NeutralAtomGateset())
111+
cirq.testing.assert_has_diagram(converted, expected_diagram)
112+
with cirq.testing.assert_deprecated(
113+
"Use cirq.optimize_for_target_gateset", deadline='v0.16', count=2
114+
):
115+
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
116+
cirq.testing.assert_has_diagram(c, expected_diagram)

cirq-core/cirq/neutral_atoms/neutral_atom_devices.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from cirq.devices.grid_qubit import GridQubit
2121
from cirq.ops import raw_types
2222
from cirq.value import Duration
23-
from cirq.neutral_atoms import convert_to_neutral_atom_gates
23+
from cirq.neutral_atoms.convert_to_neutral_atom_gates import ConvertToNeutralAtomGates
24+
from cirq.neutral_atoms.neutral_atom_gateset import NeutralAtomGateset
2425

2526
if TYPE_CHECKING:
2627
import cirq
@@ -31,22 +32,6 @@ def _subgate_if_parallel_gate(gate: 'cirq.Gate') -> 'cirq.Gate':
3132
return gate.sub_gate if isinstance(gate, ops.ParallelGate) else gate
3233

3334

34-
def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
35-
return ops.Gateset(
36-
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
37-
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
38-
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
39-
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
40-
ops.ParallelGateFamily(ops.ZPowGate, max_parallel_allowed=max_parallel_z),
41-
ops.ParallelGateFamily(ops.XPowGate, max_parallel_allowed=max_parallel_xy),
42-
ops.ParallelGateFamily(ops.YPowGate, max_parallel_allowed=max_parallel_xy),
43-
ops.ParallelGateFamily(ops.PhasedXPowGate, max_parallel_allowed=max_parallel_xy),
44-
ops.MeasurementGate,
45-
ops.IdentityGate,
46-
unroll_circuit_op=False,
47-
)
48-
49-
5035
@value.value_equality
5136
class NeutralAtomDevice(devices.Device):
5237
"""A device with qubits placed on a grid."""
@@ -107,7 +92,7 @@ def __init__(
10792
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
10893
unroll_circuit_op=False,
10994
)
110-
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
95+
self.gateset = NeutralAtomGateset(max_parallel_z, max_parallel_xy)
11196
for q in qubits:
11297
if not isinstance(q, GridQubit):
11398
raise ValueError(f'Unsupported qubit type: {q!r}')
@@ -133,7 +118,7 @@ def qubit_list(self):
133118
deadline='v0.15',
134119
)
135120
def decompose_operation(self, operation: ops.Operation) -> ops.OP_TREE:
136-
return convert_to_neutral_atom_gates.ConvertToNeutralAtomGates().convert(operation)
121+
return ConvertToNeutralAtomGates().convert(operation)
137122

138123
def duration_of(self, operation: ops.Operation):
139124
"""Provides the duration of the given operation on this device.

cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_init_errors():
114114

115115
def test_decompose_error_deprecated():
116116
d = square_device(2, 2, holes=[cirq.GridQubit(1, 1)])
117-
with cirq.testing.assert_deprecated('ConvertToNeutralAtomGates', deadline='v0.15'):
117+
with cirq.testing.assert_deprecated('ConvertToNeutralAtomGates', deadline='v0.15', count=2):
118118
for op in d.decompose_operation((cirq.CCZ**1.5).on(*(d.qubit_list()))):
119119
d.validate_operation(op)
120120

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import List, Optional, TYPE_CHECKING
15+
16+
from cirq import ops, protocols, transformers
17+
from cirq.protocols.decompose_protocol import DecomposeResult
18+
19+
20+
if TYPE_CHECKING:
21+
import cirq
22+
23+
24+
class NeutralAtomGateset(transformers.CompilationTargetGateset):
25+
"""A Compilation target intended for neutral atom devices.
26+
27+
This gateset supports CNOT, CCNOT (TOFFOLI) gates, CZ,
28+
CCZ gates, as well as single qubits gates that can be used
29+
in a parallel fashion. The maximum amount of parallelism
30+
can be set by arguments.
31+
32+
This compilation gateset decomposes operations into CZ
33+
because CZ gates are the highest fidelity two qubit gates
34+
for neutral atoms.
35+
36+
Args:
37+
max_parallel_z: The maximum amount of parallelism for
38+
Z gates.
39+
max_parallel_xy: The maximum amount of parallelism for
40+
X, Y and PhasedXPow gates.
41+
"""
42+
43+
def __init__(self, max_parallel_z: Optional[int] = None, max_parallel_xy: Optional[int] = None):
44+
super().__init__(
45+
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
46+
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
47+
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
48+
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
49+
ops.ParallelGateFamily(ops.ZPowGate, max_parallel_allowed=max_parallel_z),
50+
ops.ParallelGateFamily(ops.XPowGate, max_parallel_allowed=max_parallel_xy),
51+
ops.ParallelGateFamily(ops.YPowGate, max_parallel_allowed=max_parallel_xy),
52+
ops.ParallelGateFamily(ops.PhasedXPowGate, max_parallel_allowed=max_parallel_xy),
53+
ops.MeasurementGate,
54+
ops.IdentityGate,
55+
unroll_circuit_op=False,
56+
)
57+
58+
def num_qubits(self) -> int:
59+
"""Maximum number of qubits on which a gate from this gateset can act upon."""
60+
return 2
61+
62+
def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) -> DecomposeResult:
63+
"""Method to rewrite the given operation using gates from this gateset.
64+
Args:
65+
op: `cirq.Operation` to be rewritten using gates from this gateset.
66+
moment_idx: Moment index where the given operation `op` occurs in a circuit.
67+
Returns:
68+
- An equivalent `cirq.OP_TREE` implementing `op` using gates from this gateset.
69+
- `None` or `NotImplemented` if does not know how to decompose `op`.
70+
"""
71+
# Known matrix?
72+
mat = protocols.unitary(op, None) if len(op.qubits) <= 2 else None
73+
if mat is not None and len(op.qubits) == 1:
74+
gates = transformers.single_qubit_matrix_to_phased_x_z(mat)
75+
return [g.on(op.qubits[0]) for g in gates]
76+
if mat is not None and len(op.qubits) == 2:
77+
return transformers.two_qubit_matrix_to_cz_operations(
78+
op.qubits[0], op.qubits[1], mat, allow_partial_czs=False, clean_operations=True
79+
)
80+
81+
return NotImplemented
82+
83+
@property
84+
def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
85+
return []
86+
87+
@property
88+
def postprocess_transformers(self) -> List['cirq.TRANSFORMER']:
89+
return []
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
import cirq
17+
import cirq.neutral_atoms.neutral_atom_gateset as nag
18+
19+
Q = cirq.LineQubit.range(10)
20+
21+
22+
@pytest.mark.parametrize(
23+
'op,is_in_gateset',
24+
(
25+
(cirq.X.on(Q[0]), True),
26+
(cirq.Y.on(Q[0]), True),
27+
(cirq.I.on(Q[0]), True),
28+
(cirq.measure(*Q), True),
29+
(cirq.ParallelGate(cirq.X, 3).on(*Q[:3]), True),
30+
(cirq.ParallelGate(cirq.X, 4).on(*Q[:4]), False),
31+
(cirq.ParallelGate(cirq.X, 10).on(*Q), False),
32+
(cirq.ParallelGate(cirq.Y, 3).on(*Q[:3]), True),
33+
(cirq.ParallelGate(cirq.Y, 4).on(*Q[:4]), False),
34+
(cirq.ParallelGate(cirq.Y, 10).on(*Q), False),
35+
(cirq.ParallelGate(cirq.Z, 3).on(*Q[:3]), True),
36+
(cirq.ParallelGate(cirq.Z, 4).on(*Q[:4]), True),
37+
(cirq.ParallelGate(cirq.Z, 5).on(*Q[:5]), False),
38+
(cirq.ParallelGate(cirq.Z, 10).on(*Q), False),
39+
(
40+
cirq.ParallelGate(cirq.PhasedXPowGate(exponent=0.5, phase_exponent=0.25), 3).on(*Q[:3]),
41+
True,
42+
),
43+
(
44+
cirq.ParallelGate(cirq.PhasedXPowGate(exponent=0.5, phase_exponent=0.25), 4).on(*Q[:4]),
45+
False,
46+
),
47+
(cirq.CNOT.on(Q[0], Q[1]), True),
48+
((cirq.CNOT**0.5).on(Q[0], Q[1]), False),
49+
(cirq.CZ.on(Q[0], Q[1]), True),
50+
((cirq.CZ**0.5).on(Q[0], Q[1]), False),
51+
(cirq.CCZ.on(Q[0], Q[1], Q[2]), True),
52+
((cirq.CCZ**0.5).on(Q[0], Q[1], Q[2]), False),
53+
((cirq.TOFFOLI**0.5).on(Q[0], Q[1], Q[2]), False),
54+
),
55+
)
56+
def test_gateset(op: cirq.Operation, is_in_gateset: bool):
57+
gateset = nag.NeutralAtomGateset(max_parallel_z=4, max_parallel_xy=3)
58+
assert gateset.validate(op) == is_in_gateset
59+
converted_ops = cirq.optimize_for_target_gateset(cirq.Circuit(op), gateset=gateset)
60+
if is_in_gateset:
61+
assert converted_ops == cirq.Circuit(op)
62+
assert gateset.validate(converted_ops)
63+
64+
65+
def test_gateset_qubits():
66+
assert nag.NeutralAtomGateset(max_parallel_z=4, max_parallel_xy=3).num_qubits() == 2

0 commit comments

Comments
 (0)