Skip to content

Commit a00783e

Browse files
committed
CompilationTargetGateset support in GridDeviceMetadata
1 parent 5ec8774 commit a00783e

File tree

4 files changed

+506
-149
lines changed

4 files changed

+506
-149
lines changed

cirq-core/cirq/devices/grid_device_metadata.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Optional,
1919
FrozenSet,
2020
Iterable,
21+
Sequence,
2122
Tuple,
2223
Dict,
2324
)
@@ -40,6 +41,7 @@ def __init__(
4041
gateset: 'cirq.Gateset',
4142
gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None,
4243
all_qubits: Optional[Iterable['cirq.Qid']] = None,
44+
compilation_target_gatesets: Optional[Iterable['cirq.CompilationTargetGateset']] = None,
4345
):
4446
"""Create a GridDeviceMetadata object.
4547
@@ -60,6 +62,10 @@ def __init__(
6062
all_qubits: Optional iterable specifying all qubits
6163
found on the device. If None, all_qubits will
6264
be inferred from the entries in qubit_pairs.
65+
compilation_target_gatesets: A collection of valid
66+
`cirq.CompilationTargetGateset`s which can be used to
67+
transform circuits into ones that consist of only
68+
operations in `gateset`.
6369
6470
Raises:
6571
ValueError: if the union of GateFamily keys in gate_durations
@@ -102,6 +108,9 @@ def __init__(
102108
self._qubit_pairs = frozenset(edge_set)
103109
self._gateset = gateset
104110
self._isolated_qubits = all_qubits.difference(node_set)
111+
self._compilation_target_gatesets = (
112+
() if compilation_target_gatesets is None else tuple(compilation_target_gatesets)
113+
)
105114

106115
if gate_durations is not None:
107116
working_gatefamilies = frozenset(gate_durations.keys())
@@ -131,6 +140,11 @@ def gateset(self) -> 'cirq.Gateset':
131140
"""Returns the `cirq.Gateset` of supported gates on this device."""
132141
return self._gateset
133142

143+
@property
144+
def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset', ...]:
145+
"""Returns a sequence of valid `cirq.CompilationTargetGateset`s for this device."""
146+
return self._compilation_target_gatesets
147+
134148
@property
135149
def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]:
136150
"""Get a dictionary mapping from gateset to duration for gates."""
@@ -146,27 +160,46 @@ def _value_equality_values_(self):
146160
self._gateset,
147161
tuple(duration_equality),
148162
tuple(sorted(self.qubit_set)),
163+
frozenset(self._compilation_target_gatesets),
149164
)
150165

151166
def __repr__(self) -> str:
152167
return (
153168
f'cirq.GridDeviceMetadata({repr(self._qubit_pairs)},'
154169
f' {repr(self._gateset)}, {repr(self._gate_durations)},'
155-
f' {repr(self.qubit_set)})'
170+
f' {repr(self.qubit_set)}, {repr(self._compilation_target_gatesets)})'
156171
)
157172

158173
def _json_dict_(self):
159174
duration_payload = None
160175
if self._gate_durations is not None:
161176
duration_payload = sorted(self._gate_durations.items(), key=lambda x: repr(x[0]))
162177

163-
return {
178+
jd = {
164179
'qubit_pairs': sorted(list(self._qubit_pairs)),
165180
'gateset': self._gateset,
166181
'gate_durations': duration_payload,
167182
'all_qubits': sorted(list(self.qubit_set)),
168183
}
184+
if len(self._compilation_target_gatesets) > 0:
185+
jd['compilation_target_gatesets'] = list(self._compilation_target_gatesets)
186+
187+
return jd
169188

170189
@classmethod
171-
def _from_json_dict_(cls, qubit_pairs, gateset, gate_durations, all_qubits, **kwargs):
172-
return cls(qubit_pairs, gateset, dict(gate_durations), all_qubits)
190+
def _from_json_dict_(
191+
cls,
192+
qubit_pairs,
193+
gateset,
194+
gate_durations,
195+
all_qubits,
196+
compilation_target_gatesets=None,
197+
**kwargs,
198+
):
199+
return cls(
200+
qubit_pairs,
201+
gateset,
202+
None if gate_durations is None else dict(gate_durations),
203+
all_qubits,
204+
compilation_target_gatesets,
205+
)

cirq-core/cirq/devices/grid_device_metadata_test.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ def test_griddevice_metadata():
2323
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
2424
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
2525
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
26-
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, all_qubits=qubits + isolated_qubits)
26+
target_gatesets = (cirq.CZTargetGateset(),)
27+
metadata = cirq.GridDeviceMetadata(
28+
qubit_pairs,
29+
gateset,
30+
all_qubits=qubits + isolated_qubits,
31+
compilation_target_gatesets=target_gatesets,
32+
)
2733
expected_pairings = frozenset(
2834
{
2935
(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)),
@@ -45,6 +51,7 @@ def test_griddevice_metadata():
4551
assert metadata.nx_graph.nodes() == expected_graph.nodes()
4652
assert metadata.gate_durations is None
4753
assert metadata.isolated_qubits == frozenset(isolated_qubits)
54+
assert metadata.compilation_target_gatesets == target_gatesets
4855

4956

5057
def test_griddevice_metadata_bad_durations():
@@ -80,35 +87,58 @@ def test_griddevice_self_loop():
8087
def test_griddevice_json_load():
8188
qubits = cirq.GridQubit.rect(2, 3)
8289
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
83-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
90+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
8491
duration = {
8592
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
8693
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=2),
8794
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=3),
95+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
8896
}
8997
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
98+
target_gatesets = [cirq.CZTargetGateset()]
9099
metadata = cirq.GridDeviceMetadata(
91-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
100+
qubit_pairs,
101+
gateset,
102+
gate_durations=duration,
103+
all_qubits=qubits + isolated_qubits,
104+
compilation_target_gatesets=target_gatesets,
92105
)
93106
rep_str = cirq.to_json(metadata)
94107
assert metadata == cirq.read_json(json_text=rep_str)
95108

96109

110+
def test_griddevice_json_load_with_defaults():
111+
qubits = cirq.GridQubit.rect(2, 3)
112+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
113+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
114+
115+
# Don't set parameters with default values
116+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset)
117+
rep_str = cirq.to_json(metadata)
118+
119+
assert metadata == cirq.read_json(json_text=rep_str)
120+
121+
97122
def test_griddevice_metadata_equality():
98123
qubits = cirq.GridQubit.rect(2, 3)
99124
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
100-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
125+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ, cirq.SQRT_ISWAP)
101126
duration = {
102127
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
103128
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
104129
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
130+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
131+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=5),
105132
}
106133
duration2 = {
107134
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=10),
108135
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=13),
109136
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=12),
137+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=14),
138+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=15),
110139
}
111140
isolated_qubits = [cirq.GridQubit(9, 9)]
141+
target_gatesets = [cirq.CZTargetGateset(), cirq.SqrtIswapTargetGateset()]
112142
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
113143
metadata2 = cirq.GridDeviceMetadata(qubit_pairs[:2], gateset, gate_durations=duration)
114144
metadata3 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=None)
@@ -117,28 +147,47 @@ def test_griddevice_metadata_equality():
117147
metadata6 = cirq.GridDeviceMetadata(
118148
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
119149
)
150+
metadata7 = cirq.GridDeviceMetadata(
151+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets
152+
)
153+
metadata8 = cirq.GridDeviceMetadata(
154+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets[::-1]
155+
)
156+
metadata9 = cirq.GridDeviceMetadata(
157+
qubit_pairs, gateset, compilation_target_gatesets=tuple(target_gatesets)
158+
)
159+
metadata10 = cirq.GridDeviceMetadata(
160+
qubit_pairs, gateset, compilation_target_gatesets=set(target_gatesets)
161+
)
120162

121163
eq = cirq.testing.EqualsTester()
122164
eq.add_equality_group(metadata)
123165
eq.add_equality_group(metadata2)
124166
eq.add_equality_group(metadata3)
125167
eq.add_equality_group(metadata4)
126168
eq.add_equality_group(metadata6)
169+
eq.add_equality_group(metadata7, metadata8, metadata9, metadata10)
127170

128171
assert metadata == metadata5
129172

130173

131174
def test_repr():
132175
qubits = cirq.GridQubit.rect(2, 3)
133176
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
134-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
177+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
135178
duration = {
136179
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
137180
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
138181
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
182+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
139183
}
140184
isolated_qubits = [cirq.GridQubit(9, 9)]
185+
target_gatesets = [cirq.CZTargetGateset()]
141186
metadata = cirq.GridDeviceMetadata(
142-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
187+
qubit_pairs,
188+
gateset,
189+
gate_durations=duration,
190+
all_qubits=qubits + isolated_qubits,
191+
compilation_target_gatesets=target_gatesets,
143192
)
144193
cirq.testing.assert_equivalent_repr(metadata)

0 commit comments

Comments
 (0)