Skip to content

Commit 598d290

Browse files
committed
CompilationTargetGateset support in GridDeviceMetadata
1 parent c0c0e5f commit 598d290

6 files changed

+375
-10
lines changed

cirq-core/cirq/devices/grid_device_metadata.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
gateset: 'cirq.Gateset',
3434
gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None,
3535
all_qubits: Optional[Iterable['cirq.Qid']] = None,
36+
compilation_target_gatesets: Iterable['cirq.CompilationTargetGateset'] = (),
3637
):
3738
"""Create a GridDeviceMetadata object.
3839
@@ -53,6 +54,10 @@ def __init__(
5354
all_qubits: Optional iterable specifying all qubits
5455
found on the device. If None, all_qubits will
5556
be inferred from the entries in qubit_pairs.
57+
compilation_target_gatesets: A collection of valid
58+
`cirq.CompilationTargetGateset`s which can be used to
59+
transform circuits into ones that consist of only
60+
operations in `gateset`.
5661
5762
Raises:
5863
ValueError: if some GateFamily keys in gate_durations are
@@ -95,6 +100,7 @@ def __init__(
95100
self._qubit_pairs = frozenset({frozenset(pair) for pair in edge_set})
96101
self._gateset = gateset
97102
self._isolated_qubits = all_qubits.difference(node_set)
103+
self._compilation_target_gatesets = tuple(compilation_target_gatesets)
98104

99105
if gate_durations is not None:
100106
working_gatefamilies = frozenset(gate_durations.keys())
@@ -126,6 +132,11 @@ def gateset(self) -> 'cirq.Gateset':
126132
"""Returns the `cirq.Gateset` of supported gates on this device."""
127133
return self._gateset
128134

135+
@property
136+
def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset', ...]:
137+
"""Returns a sequence of valid `cirq.CompilationTargetGateset`s for this device."""
138+
return self._compilation_target_gatesets
139+
129140
@property
130141
def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]:
131142
"""Get a dictionary mapping from gateset to duration for gates."""
@@ -141,14 +152,15 @@ def _value_equality_values_(self):
141152
self._gateset,
142153
tuple(duration_equality),
143154
tuple(sorted(self.qubit_set)),
155+
frozenset(self._compilation_target_gatesets),
144156
)
145157

146158
def __repr__(self) -> str:
147159
qubit_pair_tuples = frozenset({tuple(sorted(p)) for p in self._qubit_pairs})
148160
return (
149161
f'cirq.GridDeviceMetadata({repr(qubit_pair_tuples)},'
150162
f' {repr(self._gateset)}, {repr(self._gate_durations)},'
151-
f' {repr(self.qubit_set)})'
163+
f' {repr(self.qubit_set)}, {repr(self._compilation_target_gatesets)})'
152164
)
153165

154166
def _json_dict_(self):
@@ -161,8 +173,23 @@ def _json_dict_(self):
161173
'gateset': self._gateset,
162174
'gate_durations': duration_payload,
163175
'all_qubits': sorted(list(self.qubit_set)),
176+
'compilation_target_gatesets': list(self._compilation_target_gatesets),
164177
}
165178

166179
@classmethod
167-
def _from_json_dict_(cls, qubit_pairs, gateset, gate_durations, all_qubits, **kwargs):
168-
return cls(qubit_pairs, gateset, dict(gate_durations), all_qubits)
180+
def _from_json_dict_(
181+
cls,
182+
qubit_pairs,
183+
gateset,
184+
gate_durations,
185+
all_qubits,
186+
compilation_target_gatesets=(),
187+
**kwargs,
188+
):
189+
return cls(
190+
qubit_pairs,
191+
gateset,
192+
None if gate_durations is None else dict(gate_durations),
193+
all_qubits,
194+
compilation_target_gatesets,
195+
)

cirq-core/cirq/devices/grid_device_metadata_test.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ def test_griddevice_metadata():
2929
cirq.GateFamily(cirq.ZPowGate): 1_000,
3030
# omitting cirq.CZ
3131
}
32+
target_gatesets = (cirq.CZTargetGateset(),)
3233
metadata = cirq.GridDeviceMetadata(
33-
qubit_pairs, gateset, gate_durations=gate_durations, all_qubits=qubits + isolated_qubits
34+
qubit_pairs,
35+
gateset,
36+
gate_durations=gate_durations,
37+
all_qubits=qubits + isolated_qubits,
38+
compilation_target_gatesets=target_gatesets,
3439
)
3540
expected_pairings = frozenset(
3641
{
@@ -53,6 +58,7 @@ def test_griddevice_metadata():
5358
assert metadata.nx_graph.nodes() == expected_graph.nodes()
5459
assert metadata.gate_durations == gate_durations
5560
assert metadata.isolated_qubits == frozenset(isolated_qubits)
61+
assert metadata.compilation_target_gatesets == target_gatesets
5662

5763

5864
def test_griddevice_metadata_bad_durations():
@@ -88,35 +94,58 @@ def test_griddevice_self_loop():
8894
def test_griddevice_json_load():
8995
qubits = cirq.GridQubit.rect(2, 3)
9096
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
91-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
97+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
9298
duration = {
9399
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
94100
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=2),
95101
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=3),
102+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
96103
}
97104
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
105+
target_gatesets = [cirq.CZTargetGateset()]
98106
metadata = cirq.GridDeviceMetadata(
99-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
107+
qubit_pairs,
108+
gateset,
109+
gate_durations=duration,
110+
all_qubits=qubits + isolated_qubits,
111+
compilation_target_gatesets=target_gatesets,
100112
)
101113
rep_str = cirq.to_json(metadata)
102114
assert metadata == cirq.read_json(json_text=rep_str)
103115

104116

117+
def test_griddevice_json_load_with_defaults():
118+
qubits = cirq.GridQubit.rect(2, 3)
119+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
120+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
121+
122+
# Don't set parameters with default values
123+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset)
124+
rep_str = cirq.to_json(metadata)
125+
126+
assert metadata == cirq.read_json(json_text=rep_str)
127+
128+
105129
def test_griddevice_metadata_equality():
106130
qubits = cirq.GridQubit.rect(2, 3)
107131
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
108-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
132+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ, cirq.SQRT_ISWAP)
109133
duration = {
110134
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
111135
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
112136
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
137+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
138+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=5),
113139
}
114140
duration2 = {
115141
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=10),
116142
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=13),
117143
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=12),
144+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=14),
145+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=15),
118146
}
119147
isolated_qubits = [cirq.GridQubit(9, 9)]
148+
target_gatesets = [cirq.CZTargetGateset(), cirq.SqrtIswapTargetGateset()]
120149
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
121150
metadata2 = cirq.GridDeviceMetadata(qubit_pairs[:2], gateset, gate_durations=duration)
122151
metadata3 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=None)
@@ -125,28 +154,47 @@ def test_griddevice_metadata_equality():
125154
metadata6 = cirq.GridDeviceMetadata(
126155
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
127156
)
157+
metadata7 = cirq.GridDeviceMetadata(
158+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets
159+
)
160+
metadata8 = cirq.GridDeviceMetadata(
161+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets[::-1]
162+
)
163+
metadata9 = cirq.GridDeviceMetadata(
164+
qubit_pairs, gateset, compilation_target_gatesets=tuple(target_gatesets)
165+
)
166+
metadata10 = cirq.GridDeviceMetadata(
167+
qubit_pairs, gateset, compilation_target_gatesets=set(target_gatesets)
168+
)
128169

129170
eq = cirq.testing.EqualsTester()
130171
eq.add_equality_group(metadata)
131172
eq.add_equality_group(metadata2)
132173
eq.add_equality_group(metadata3)
133174
eq.add_equality_group(metadata4)
134175
eq.add_equality_group(metadata6)
176+
eq.add_equality_group(metadata7, metadata8, metadata9, metadata10)
135177

136178
assert metadata == metadata5
137179

138180

139181
def test_repr():
140182
qubits = cirq.GridQubit.rect(2, 3)
141183
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
142-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
184+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
143185
duration = {
144186
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
145187
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
146188
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
189+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
147190
}
148191
isolated_qubits = [cirq.GridQubit(9, 9)]
192+
target_gatesets = [cirq.CZTargetGateset()]
149193
metadata = cirq.GridDeviceMetadata(
150-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
194+
qubit_pairs,
195+
gateset,
196+
gate_durations=duration,
197+
all_qubits=qubits + isolated_qubits,
198+
compilation_target_gatesets=target_gatesets,
151199
)
152200
cirq.testing.assert_equivalent_repr(metadata)

cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,76 @@
115115
"ignore_global_phase": true,
116116
"tags_to_accept": [],
117117
"tags_to_ignore": []
118+
},
119+
{
120+
"cirq_type": "GateFamily",
121+
"gate": {
122+
"cirq_type": "CZPowGate",
123+
"exponent": 1.0,
124+
"global_shift": 0.0
125+
},
126+
"name": "Instance GateFamily: CZ",
127+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
128+
"ignore_global_phase": true,
129+
"tags_to_accept": [],
130+
"tags_to_ignore": []
131+
},
132+
{
133+
"cirq_type": "GateFamily",
134+
"gate": {
135+
"cirq_type": "ISwapPowGate",
136+
"exponent": 0.5,
137+
"global_shift": 0.0
138+
},
139+
"name": "Instance GateFamily: ISWAP**0.5",
140+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
141+
"ignore_global_phase": true,
142+
"tags_to_accept": [],
143+
"tags_to_ignore": []
118144
}
119145
],
120146
"name": null,
121147
"unroll_circuit_op": true
122148
},
123149
"gate_durations": [
150+
[
151+
{
152+
"cirq_type": "GateFamily",
153+
"gate": {
154+
"cirq_type": "ISwapPowGate",
155+
"exponent": 0.5,
156+
"global_shift": 0.0
157+
},
158+
"name": "Instance GateFamily: ISWAP**0.5",
159+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
160+
"ignore_global_phase": true,
161+
"tags_to_accept": [],
162+
"tags_to_ignore": []
163+
},
164+
{
165+
"cirq_type": "Duration",
166+
"picos": 600
167+
}
168+
],
169+
[
170+
{
171+
"cirq_type": "GateFamily",
172+
"gate": {
173+
"cirq_type": "CZPowGate",
174+
"exponent": 1.0,
175+
"global_shift": 0.0
176+
},
177+
"name": "Instance GateFamily: CZ",
178+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
179+
"ignore_global_phase": true,
180+
"tags_to_accept": [],
181+
"tags_to_ignore": []
182+
},
183+
{
184+
"cirq_type": "Duration",
185+
"picos": 500
186+
}
187+
],
124188
[
125189
{
126190
"cirq_type": "GateFamily",
@@ -208,5 +272,18 @@
208272
"row": 10,
209273
"col": 10
210274
}
275+
],
276+
"compilation_target_gatesets": [
277+
{
278+
"cirq_type": "CZTargetGateset",
279+
"atol": 1e-08,
280+
"allow_partial_czs": false
281+
},
282+
{
283+
"cirq_type": "SqrtIswapTargetGateset",
284+
"atol": 1e-08,
285+
"required_sqrt_iswap_count": null,
286+
"use_sqrt_iswap_inv": false
287+
}
211288
]
212289
}

0 commit comments

Comments
 (0)