Skip to content

Commit d3d2252

Browse files
committed
CompilationTargetGateset support in GridDeviceMetadata
1 parent 614d4e0 commit d3d2252

File tree

4 files changed

+163
-10
lines changed

4 files changed

+163
-10
lines changed

cirq-core/cirq/devices/grid_device_metadata.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
gateset: 'cirq.Gateset',
3434
gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None,
3535
all_qubits: Optional[Iterable['cirq.Qid']] = None,
36+
compilation_target_gatesets: Iterable['cirq.CompilationTargetGateset'] = (),
3637
):
3738
"""Create a GridDeviceMetadata object.
3839
@@ -53,6 +54,10 @@ def __init__(
5354
all_qubits: Optional iterable specifying all qubits
5455
found on the device. If None, all_qubits will
5556
be inferred from the entries in qubit_pairs.
57+
compilation_target_gatesets: A collection of valid
58+
`cirq.CompilationTargetGateset`s which can be used to
59+
transform circuits into ones that consist of only
60+
operations in `gateset`.
5661
5762
Raises:
5863
ValueError: if the union of GateFamily keys in gate_durations
@@ -95,6 +100,7 @@ def __init__(
95100
self._qubit_pairs = frozenset({frozenset(pair) for pair in edge_set})
96101
self._gateset = gateset
97102
self._isolated_qubits = all_qubits.difference(node_set)
103+
self._compilation_target_gatesets = tuple(compilation_target_gatesets)
98104

99105
if gate_durations is not None:
100106
working_gatefamilies = frozenset(gate_durations.keys())
@@ -128,6 +134,11 @@ def gateset(self) -> 'cirq.Gateset':
128134
"""Returns the `cirq.Gateset` of supported gates on this device."""
129135
return self._gateset
130136

137+
@property
138+
def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset', ...]:
139+
"""Returns a sequence of valid `cirq.CompilationTargetGateset`s for this device."""
140+
return self._compilation_target_gatesets
141+
131142
@property
132143
def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]:
133144
"""Get a dictionary mapping from gateset to duration for gates."""
@@ -143,14 +154,15 @@ def _value_equality_values_(self):
143154
self._gateset,
144155
tuple(duration_equality),
145156
tuple(sorted(self.qubit_set)),
157+
frozenset(self._compilation_target_gatesets),
146158
)
147159

148160
def __repr__(self) -> str:
149161
qubit_pair_tuples = frozenset({tuple(sorted(p)) for p in self._qubit_pairs})
150162
return (
151163
f'cirq.GridDeviceMetadata({repr(qubit_pair_tuples)},'
152164
f' {repr(self._gateset)}, {repr(self._gate_durations)},'
153-
f' {repr(self.qubit_set)})'
165+
f' {repr(self.qubit_set)}, {repr(self._compilation_target_gatesets)})'
154166
)
155167

156168
def _json_dict_(self):
@@ -163,8 +175,23 @@ def _json_dict_(self):
163175
'gateset': self._gateset,
164176
'gate_durations': duration_payload,
165177
'all_qubits': sorted(list(self.qubit_set)),
178+
'compilation_target_gatesets': list(self._compilation_target_gatesets),
166179
}
167180

168181
@classmethod
169-
def _from_json_dict_(cls, qubit_pairs, gateset, gate_durations, all_qubits, **kwargs):
170-
return cls(qubit_pairs, gateset, dict(gate_durations), all_qubits)
182+
def _from_json_dict_(
183+
cls,
184+
qubit_pairs,
185+
gateset,
186+
gate_durations,
187+
all_qubits,
188+
compilation_target_gatesets=(),
189+
**kwargs,
190+
):
191+
return cls(
192+
qubit_pairs,
193+
gateset,
194+
None if gate_durations is None else dict(gate_durations),
195+
all_qubits,
196+
compilation_target_gatesets,
197+
)

cirq-core/cirq/devices/grid_device_metadata_test.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ def test_griddevice_metadata():
2323
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
2424
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
2525
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
26-
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, all_qubits=qubits + isolated_qubits)
26+
target_gatesets = (cirq.CZTargetGateset(),)
27+
metadata = cirq.GridDeviceMetadata(
28+
qubit_pairs,
29+
gateset,
30+
all_qubits=qubits + isolated_qubits,
31+
compilation_target_gatesets=target_gatesets,
32+
)
2733
expected_pairings = frozenset(
2834
{
2935
frozenset((cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))),
@@ -45,6 +51,7 @@ def test_griddevice_metadata():
4551
assert metadata.nx_graph.nodes() == expected_graph.nodes()
4652
assert metadata.gate_durations is None
4753
assert metadata.isolated_qubits == frozenset(isolated_qubits)
54+
assert metadata.compilation_target_gatesets == target_gatesets
4855

4956

5057
def test_griddevice_metadata_bad_durations():
@@ -80,35 +87,58 @@ def test_griddevice_self_loop():
8087
def test_griddevice_json_load():
8188
qubits = cirq.GridQubit.rect(2, 3)
8289
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
83-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
90+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
8491
duration = {
8592
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
8693
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=2),
8794
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=3),
95+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
8896
}
8997
isolated_qubits = [cirq.GridQubit(9, 9), cirq.GridQubit(10, 10)]
98+
target_gatesets = [cirq.CZTargetGateset()]
9099
metadata = cirq.GridDeviceMetadata(
91-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
100+
qubit_pairs,
101+
gateset,
102+
gate_durations=duration,
103+
all_qubits=qubits + isolated_qubits,
104+
compilation_target_gatesets=target_gatesets,
92105
)
93106
rep_str = cirq.to_json(metadata)
94107
assert metadata == cirq.read_json(json_text=rep_str)
95108

96109

110+
def test_griddevice_json_load_with_defaults():
111+
qubits = cirq.GridQubit.rect(2, 3)
112+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
113+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
114+
115+
# Don't set parameters with default values
116+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset)
117+
rep_str = cirq.to_json(metadata)
118+
119+
assert metadata == cirq.read_json(json_text=rep_str)
120+
121+
97122
def test_griddevice_metadata_equality():
98123
qubits = cirq.GridQubit.rect(2, 3)
99124
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
100-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
125+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ, cirq.SQRT_ISWAP)
101126
duration = {
102127
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
103128
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
104129
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
130+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
131+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=5),
105132
}
106133
duration2 = {
107134
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=10),
108135
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=13),
109136
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=12),
137+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=14),
138+
cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(nanos=15),
110139
}
111140
isolated_qubits = [cirq.GridQubit(9, 9)]
141+
target_gatesets = [cirq.CZTargetGateset(), cirq.SqrtIswapTargetGateset()]
112142
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
113143
metadata2 = cirq.GridDeviceMetadata(qubit_pairs[:2], gateset, gate_durations=duration)
114144
metadata3 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=None)
@@ -117,28 +147,47 @@ def test_griddevice_metadata_equality():
117147
metadata6 = cirq.GridDeviceMetadata(
118148
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
119149
)
150+
metadata7 = cirq.GridDeviceMetadata(
151+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets
152+
)
153+
metadata8 = cirq.GridDeviceMetadata(
154+
qubit_pairs, gateset, compilation_target_gatesets=target_gatesets[::-1]
155+
)
156+
metadata9 = cirq.GridDeviceMetadata(
157+
qubit_pairs, gateset, compilation_target_gatesets=tuple(target_gatesets)
158+
)
159+
metadata10 = cirq.GridDeviceMetadata(
160+
qubit_pairs, gateset, compilation_target_gatesets=set(target_gatesets)
161+
)
120162

121163
eq = cirq.testing.EqualsTester()
122164
eq.add_equality_group(metadata)
123165
eq.add_equality_group(metadata2)
124166
eq.add_equality_group(metadata3)
125167
eq.add_equality_group(metadata4)
126168
eq.add_equality_group(metadata6)
169+
eq.add_equality_group(metadata7, metadata8, metadata9, metadata10)
127170

128171
assert metadata == metadata5
129172

130173

131174
def test_repr():
132175
qubits = cirq.GridQubit.rect(2, 3)
133176
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
134-
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
177+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
135178
duration = {
136179
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
137180
cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3),
138181
cirq.GateFamily(cirq.ZPowGate): cirq.Duration(picos=2),
182+
cirq.GateFamily(cirq.CZ): cirq.Duration(nanos=4),
139183
}
140184
isolated_qubits = [cirq.GridQubit(9, 9)]
185+
target_gatesets = [cirq.CZTargetGateset()]
141186
metadata = cirq.GridDeviceMetadata(
142-
qubit_pairs, gateset, gate_durations=duration, all_qubits=qubits + isolated_qubits
187+
qubit_pairs,
188+
gateset,
189+
gate_durations=duration,
190+
all_qubits=qubits + isolated_qubits,
191+
compilation_target_gatesets=target_gatesets,
143192
)
144193
cirq.testing.assert_equivalent_repr(metadata)

cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,76 @@
115115
"ignore_global_phase": true,
116116
"tags_to_accept": [],
117117
"tags_to_ignore": []
118+
},
119+
{
120+
"cirq_type": "GateFamily",
121+
"gate": {
122+
"cirq_type": "CZPowGate",
123+
"exponent": 1.0,
124+
"global_shift": 0.0
125+
},
126+
"name": "Instance GateFamily: CZ",
127+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
128+
"ignore_global_phase": true,
129+
"tags_to_accept": [],
130+
"tags_to_ignore": []
131+
},
132+
{
133+
"cirq_type": "GateFamily",
134+
"gate": {
135+
"cirq_type": "ISwapPowGate",
136+
"exponent": 0.5,
137+
"global_shift": 0.0
138+
},
139+
"name": "Instance GateFamily: ISWAP**0.5",
140+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
141+
"ignore_global_phase": true,
142+
"tags_to_accept": [],
143+
"tags_to_ignore": []
118144
}
119145
],
120146
"name": null,
121147
"unroll_circuit_op": true
122148
},
123149
"gate_durations": [
150+
[
151+
{
152+
"cirq_type": "GateFamily",
153+
"gate": {
154+
"cirq_type": "ISwapPowGate",
155+
"exponent": 0.5,
156+
"global_shift": 0.0
157+
},
158+
"name": "Instance GateFamily: ISWAP**0.5",
159+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
160+
"ignore_global_phase": true,
161+
"tags_to_accept": [],
162+
"tags_to_ignore": []
163+
},
164+
{
165+
"cirq_type": "Duration",
166+
"picos": 600
167+
}
168+
],
169+
[
170+
{
171+
"cirq_type": "GateFamily",
172+
"gate": {
173+
"cirq_type": "CZPowGate",
174+
"exponent": 1.0,
175+
"global_shift": 0.0
176+
},
177+
"name": "Instance GateFamily: CZ",
178+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
179+
"ignore_global_phase": true,
180+
"tags_to_accept": [],
181+
"tags_to_ignore": []
182+
},
183+
{
184+
"cirq_type": "Duration",
185+
"picos": 500
186+
}
187+
],
124188
[
125189
{
126190
"cirq_type": "GateFamily",
@@ -208,5 +272,18 @@
208272
"row": 10,
209273
"col": 10
210274
}
275+
],
276+
"compilation_target_gatesets": [
277+
{
278+
"cirq_type": "CZTargetGateset",
279+
"atol": 1e-08,
280+
"allow_partial_czs": false
281+
},
282+
{
283+
"cirq_type": "SqrtIswapTargetGateset",
284+
"atol": 1e-08,
285+
"required_sqrt_iswap_count": null,
286+
"use_sqrt_iswap_inv": false
287+
}
211288
]
212289
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
cirq.GridDeviceMetadata(frozenset({(cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)), (cirq.GridQubit(1, 1), cirq.GridQubit(1, 2)), (cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)), (cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)), (cirq.GridQubit(0, 0), cirq.GridQubit(1, 0)), (cirq.GridQubit(0, 2), cirq.GridQubit(1, 2)), (cirq.GridQubit(0, 1), cirq.GridQubit(1, 1))}), cirq.Gateset(cirq.ops.common_gates.XPowGate, cirq.ops.common_gates.YPowGate, cirq.ops.common_gates.ZPowGate, unroll_circuit_op = True), {cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True): cirq.Duration(nanos=1), cirq.GateFamily(gate=cirq.ops.common_gates.YPowGate, ignore_global_phase=True): cirq.Duration(picos=1), cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True): cirq.Duration(picos=1)}, frozenset({cirq.GridQubit(0, 0), cirq.GridQubit(1, 0), cirq.GridQubit(1, 2), cirq.GridQubit(10, 10), cirq.GridQubit(0, 2), cirq.GridQubit(0, 1), cirq.GridQubit(1, 1), cirq.GridQubit(9, 9)}))
1+
cirq.GridDeviceMetadata(frozenset({(cirq.GridQubit(1, 1), cirq.GridQubit(1, 2)), (cirq.GridQubit(0, 0), cirq.GridQubit(1, 0)), (cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)), (cirq.GridQubit(0, 1), cirq.GridQubit(1, 1)), (cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)), (cirq.GridQubit(0, 2), cirq.GridQubit(1, 2)), (cirq.GridQubit(0, 1), cirq.GridQubit(0, 2))}), cirq.Gateset(cirq.ops.common_gates.XPowGate, cirq.ops.common_gates.YPowGate, cirq.ops.common_gates.ZPowGate, cirq.CZ, (cirq.ISWAP**0.5), unroll_circuit_op = True), {cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()): cirq.Duration(nanos=1), cirq.GateFamily(gate=cirq.ops.common_gates.YPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()): cirq.Duration(picos=1), cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()): cirq.Duration(picos=1), cirq.GateFamily(gate=cirq.CZ, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()): cirq.Duration(picos=500), cirq.GateFamily(gate=(cirq.ISWAP**0.5), ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()): cirq.Duration(picos=600)}, frozenset({cirq.GridQubit(1, 0), cirq.GridQubit(0, 2), cirq.GridQubit(9, 9), cirq.GridQubit(10, 10), cirq.GridQubit(0, 1), cirq.GridQubit(1, 1), cirq.GridQubit(0, 0), cirq.GridQubit(1, 2)}), (cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=False), cirq.SqrtIswapTargetGateset(atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False)))

0 commit comments

Comments
 (0)