Skip to content

Commit f45e9ae

Browse files
committed
CompilationTargetGateset support in GridDeviceMetadata
1 parent 5ec8774 commit f45e9ae

File tree

4 files changed

+500
-148
lines changed

4 files changed

+500
-148
lines changed

cirq-core/cirq/devices/grid_device_metadata.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
gateset: 'cirq.Gateset',
4141
gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None,
4242
all_qubits: Optional[Iterable['cirq.Qid']] = None,
43+
compilation_target_gatesets: Iterable['cirq.CompilationTargetGateset'] = (),
4344
):
4445
"""Create a GridDeviceMetadata object.
4546
@@ -60,6 +61,10 @@ def __init__(
6061
all_qubits: Optional iterable specifying all qubits
6162
found on the device. If None, all_qubits will
6263
be inferred from the entries in qubit_pairs.
64+
compilation_target_gatesets: A collection of valid
65+
`cirq.CompilationTargetGateset`s which can be used to
66+
transform circuits into ones that consist of only
67+
operations in `gateset`.
6368
6469
Raises:
6570
ValueError: if the union of GateFamily keys in gate_durations
@@ -102,6 +107,7 @@ def __init__(
102107
self._qubit_pairs = frozenset(edge_set)
103108
self._gateset = gateset
104109
self._isolated_qubits = all_qubits.difference(node_set)
110+
self._compilation_target_gatesets = tuple(compilation_target_gatesets)
105111

106112
if gate_durations is not None:
107113
working_gatefamilies = frozenset(gate_durations.keys())
@@ -131,6 +137,11 @@ def gateset(self) -> 'cirq.Gateset':
131137
"""Returns the `cirq.Gateset` of supported gates on this device."""
132138
return self._gateset
133139

140+
@property
141+
def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset', ...]:
142+
"""Returns a sequence of valid `cirq.CompilationTargetGateset`s for this device."""
143+
return self._compilation_target_gatesets
144+
134145
@property
135146
def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]:
136147
"""Get a dictionary mapping from gateset to duration for gates."""
@@ -146,13 +157,14 @@ def _value_equality_values_(self):
146157
self._gateset,
147158
tuple(duration_equality),
148159
tuple(sorted(self.qubit_set)),
160+
frozenset(self._compilation_target_gatesets),
149161
)
150162

151163
def __repr__(self) -> str:
152164
return (
153165
f'cirq.GridDeviceMetadata({repr(self._qubit_pairs)},'
154166
f' {repr(self._gateset)}, {repr(self._gate_durations)},'
155-
f' {repr(self.qubit_set)})'
167+
f' {repr(self.qubit_set)}, {repr(self._compilation_target_gatesets)})'
156168
)
157169

158170
def _json_dict_(self):
@@ -165,8 +177,23 @@ def _json_dict_(self):
165177
'gateset': self._gateset,
166178
'gate_durations': duration_payload,
167179
'all_qubits': sorted(list(self.qubit_set)),
180+
'compilation_target_gatesets': list(self._compilation_target_gatesets)
168181
}
169182

170183
@classmethod
171-
def _from_json_dict_(cls, qubit_pairs, gateset, gate_durations, all_qubits, **kwargs):
172-
return cls(qubit_pairs, gateset, dict(gate_durations), all_qubits)
184+
def _from_json_dict_(
185+
cls,
186+
qubit_pairs,
187+
gateset,
188+
gate_durations,
189+
all_qubits,
190+
compilation_target_gatesets=(),
191+
**kwargs,
192+
):
193+
return cls(
194+
qubit_pairs,
195+
gateset,
196+
None if gate_durations is None else dict(gate_durations),
197+
all_qubits,
198+
compilation_target_gatesets,
199+
)

cirq-core/cirq/devices/grid_device_metadata_test.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ def test_griddevice_metadata():
2323
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
2424
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
2525
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
26-
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, all_qubits=qubits + isolated_qubits)
26+
target_gatesets = (cirq.CZTargetGateset(),)
27+
metadata = cirq.GridDeviceMetadata(
28+
qubit_pairs,
29+
gateset,
30+
all_qubits=qubits + isolated_qubits,
31+
compilation_target_gatesets=target_gatesets,
32+
)
2733
expected_pairings = frozenset(
2834
{
2935
(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)),
@@ -45,6 +51,7 @@ def test_griddevice_metadata():
4551
assert metadata.nx_graph.nodes() == expected_graph.nodes()
4652
assert metadata.gate_durations is None
4753
assert metadata.isolated_qubits == frozenset(isolated_qubits)
54+
assert metadata.compilation_target_gatesets == target_gatesets
4855

4956

5057
def test_griddevice_metadata_bad_durations():
@@ -80,35 +87,58 @@ def test_griddevice_self_loop():
8087
def test_griddevice_json_load():
8188
qubits = cirq.GridQubit.rect(2, 3)
8289
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
83-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
90+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
8491
duration = {
8592
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
8693
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=2),
8794
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=3),
95+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
8896
}
8997
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
98+
target_gatesets = [cirq.CZTargetGateset()]
9099
metadata = cirq.GridDeviceMetadata(
91-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
100+
qubit_pairs,
101+
gateset,
102+
gate_durations=duration,
103+
all_qubits=qubits + isolated_qubits,
104+
compilation_target_gatesets=target_gatesets,
92105
)
93106
rep_str = cirq.to_json(metadata)
94107
assert metadata == cirq.read_json(json_text=rep_str)
95108

96109

110+
def test_griddevice_json_load_with_defaults():
111+
qubits = cirq.GridQubit.rect(2, 3)
112+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
113+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
114+
115+
# Don't set parameters with default values
116+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset)
117+
rep_str = cirq.to_json(metadata)
118+
119+
assert metadata == cirq.read_json(json_text=rep_str)
120+
121+
97122
def test_griddevice_metadata_equality():
98123
qubits = cirq.GridQubit.rect(2, 3)
99124
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
100-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
125+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ, cirq.SQRT_ISWAP)
101126
duration = {
102127
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
103128
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
104129
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
130+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
131+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=5),
105132
}
106133
duration2 = {
107134
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=10),
108135
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=13),
109136
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=12),
137+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=14),
138+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=15),
110139
}
111140
isolated_qubits = [cirq.GridQubit(9, 9)]
141+
target_gatesets = [cirq.CZTargetGateset(), cirq.SqrtIswapTargetGateset()]
112142
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
113143
metadata2 = cirq.GridDeviceMetadata(qubit_pairs[:2], gateset, gate_durations=duration)
114144
metadata3 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=None)
@@ -117,28 +147,47 @@ def test_griddevice_metadata_equality():
117147
metadata6 = cirq.GridDeviceMetadata(
118148
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
119149
)
150+
metadata7 = cirq.GridDeviceMetadata(
151+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets
152+
)
153+
metadata8 = cirq.GridDeviceMetadata(
154+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets[::-1]
155+
)
156+
metadata9 = cirq.GridDeviceMetadata(
157+
qubit_pairs, gateset, compilation_target_gatesets=tuple(target_gatesets)
158+
)
159+
metadata10 = cirq.GridDeviceMetadata(
160+
qubit_pairs, gateset, compilation_target_gatesets=set(target_gatesets)
161+
)
120162

121163
eq = cirq.testing.EqualsTester()
122164
eq.add_equality_group(metadata)
123165
eq.add_equality_group(metadata2)
124166
eq.add_equality_group(metadata3)
125167
eq.add_equality_group(metadata4)
126168
eq.add_equality_group(metadata6)
169+
eq.add_equality_group(metadata7, metadata8, metadata9, metadata10)
127170

128171
assert metadata == metadata5
129172

130173

131174
def test_repr():
132175
qubits = cirq.GridQubit.rect(2, 3)
133176
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
134-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
177+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
135178
duration = {
136179
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
137180
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
138181
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
182+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
139183
}
140184
isolated_qubits = [cirq.GridQubit(9, 9)]
185+
target_gatesets = [cirq.CZTargetGateset()]
141186
metadata = cirq.GridDeviceMetadata(
142-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
187+
qubit_pairs,
188+
gateset,
189+
gate_durations=duration,
190+
all_qubits=qubits + isolated_qubits,
191+
compilation_target_gatesets=target_gatesets,
143192
)
144193
cirq.testing.assert_equivalent_repr(metadata)

0 commit comments

Comments
 (0)