Skip to content

Commit 59066a5

Browse files
Add GridDeviceMetadata. (quantumlib#4839)
Adds GridDeviceMetadata implementation as part of quantumlib#4743 .
1 parent 525cff0 commit 59066a5

8 files changed

+450
-1
lines changed

cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
ConstantQubitNoiseModel,
8383
Device,
8484
DeviceMetadata,
85+
GridDeviceMetadata,
8586
GridQid,
8687
GridQubit,
8788
LineQid,

cirq/devices/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
SymmetricalQidPair,
2020
)
2121

22+
from cirq.devices.grid_device_metadata import (
23+
GridDeviceMetadata,
24+
)
25+
2226
from cirq.devices.grid_qubit import (
2327
GridQid,
2428
GridQubit,

cirq/devices/device.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313
# limitations under the License.
1414

1515
import abc
16-
from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator, Iterable
16+
from typing import (
17+
TYPE_CHECKING,
18+
Optional,
19+
AbstractSet,
20+
cast,
21+
FrozenSet,
22+
Iterator,
23+
Iterable,
24+
)
1725

1826
import networkx as nx
1927
from cirq import value

cirq/devices/grid_device_metadata.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Metadata subtype for 2D Homogenous devices."""
15+
16+
from typing import (
17+
TYPE_CHECKING,
18+
Optional,
19+
FrozenSet,
20+
Iterable,
21+
Tuple,
22+
Dict,
23+
)
24+
25+
import networkx as nx
26+
from cirq import value
27+
from cirq.devices import device
28+
29+
if TYPE_CHECKING:
30+
import cirq
31+
32+
33+
@value.value_equality
34+
class GridDeviceMetadata(device.DeviceMetadata):
35+
"""Hardware metadata for homogenous 2d symmetric grid devices."""
36+
37+
def __init__(
38+
self,
39+
qubit_pairs: Iterable[Tuple['cirq.Qid', 'cirq.Qid']],
40+
supported_gates: 'cirq.Gateset',
41+
gate_durations: Optional[Dict['cirq.Gateset', 'cirq.Duration']] = None,
42+
):
43+
"""Create a GridDeviceMetadata object.
44+
45+
Create a GridDevice which has a well defined set of couplable
46+
qubit pairs that have the same two qubit gates available in
47+
both coupling directions.
48+
49+
Args:
50+
qubit_pairs: Iterable of pairs of `cirq.Qid`s representing
51+
bi-directional couplings.
52+
supported_gates: `cirq.Gateset` indicating gates supported
53+
everywhere on the device.
54+
gate_durations: Optional dictionary of `cirq.Gateset`
55+
instances mapping to `cirq.Duration` instances for
56+
gate timing metadata information. If provided,
57+
must match all entries in supported_gates.
58+
59+
Raises:
60+
ValueError: if the union of gateset keys in gate_durations,
61+
do not represent an identical gateset to supported_gates.
62+
"""
63+
qubit_pairs = list(qubit_pairs)
64+
flat_pairs = [q for pair in qubit_pairs for q in pair]
65+
# Keep lexigraphically smaller tuples for undirected edges.
66+
sorted_pairs = sorted(qubit_pairs)
67+
pair_set = set()
68+
for a, b in sorted_pairs:
69+
if (b, a) not in pair_set:
70+
pair_set.add((a, b))
71+
72+
connectivity = nx.Graph()
73+
connectivity.add_edges_from(sorted(pair_set), directed=False)
74+
super().__init__(flat_pairs, connectivity)
75+
self._qubit_pairs = frozenset(pair_set)
76+
self._supported_gates = supported_gates
77+
78+
if gate_durations is not None:
79+
working_gatefamilies = frozenset(
80+
g for gset in gate_durations.keys() for g in gset.gates
81+
)
82+
if working_gatefamilies != supported_gates.gates:
83+
missing_items = working_gatefamilies.difference(supported_gates.gates)
84+
raise ValueError(
85+
"Supplied gate_durations contains gates not present"
86+
f" in supported_gates. {missing_items} in supported_gates"
87+
" is False."
88+
)
89+
90+
self._gate_durations = gate_durations
91+
92+
@property
93+
def qubit_pairs(self) -> FrozenSet[Tuple['cirq.Qid', 'cirq.Qid']]:
94+
"""Returns the set of all couple-able qubits on the device."""
95+
return self._qubit_pairs
96+
97+
@property
98+
def gateset(self) -> 'cirq.Gateset':
99+
"""Returns the `cirq.Gateset` of supported gates on this device."""
100+
return self._supported_gates
101+
102+
@property
103+
def gate_durations(self) -> Optional[Dict['cirq.Gateset', 'cirq.Duration']]:
104+
"""Get a dictionary mapping from gateset to duration for gates."""
105+
return self._gate_durations
106+
107+
def _value_equality_values_(self):
108+
duration_equality = ''
109+
if self._gate_durations is not None:
110+
duration_equality = sorted(self._gate_durations.items(), key=lambda x: repr(x[0]))
111+
112+
return (
113+
tuple(sorted(self._qubit_pairs)),
114+
self._supported_gates,
115+
tuple(duration_equality),
116+
)
117+
118+
def __repr__(self) -> str:
119+
return (
120+
f'cirq.GridDeviceMetadata({repr(self._qubit_pairs)},'
121+
f' {repr(self._supported_gates)}, {repr(self._gate_durations)})'
122+
)
123+
124+
def _json_dict_(self):
125+
duration_payload = None
126+
if self._gate_durations is not None:
127+
duration_payload = sorted(self._gate_durations.items(), key=lambda x: repr(x[0]))
128+
129+
return {
130+
'qubit_pairs': sorted(list(self._qubit_pairs)),
131+
'supported_gates': self._supported_gates,
132+
'gate_durations': duration_payload,
133+
}
134+
135+
@classmethod
136+
def _from_json_dict_(cls, qubit_pairs, supported_gates, gate_durations, **kwargs):
137+
return cls(qubit_pairs, supported_gates, dict(gate_durations))
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for GridDevicemetadata."""
15+
16+
import pytest
17+
import cirq
18+
import networkx as nx
19+
20+
21+
def test_griddevice_metadata():
22+
qubits = cirq.GridQubit.rect(2, 3)
23+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
24+
25+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZ)
26+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset)
27+
28+
expected_pairings = frozenset(
29+
{
30+
(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)),
31+
(cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)),
32+
(cirq.GridQubit(0, 1), cirq.GridQubit(1, 1)),
33+
(cirq.GridQubit(0, 2), cirq.GridQubit(1, 2)),
34+
(cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)),
35+
(cirq.GridQubit(1, 1), cirq.GridQubit(1, 2)),
36+
(cirq.GridQubit(0, 0), cirq.GridQubit(1, 0)),
37+
}
38+
)
39+
assert metadata.qubit_set == frozenset(qubits)
40+
assert metadata.qubit_pairs == expected_pairings
41+
assert metadata.gateset == gateset
42+
expected_graph = nx.Graph()
43+
expected_graph.add_edges_from(sorted(list(expected_pairings)), directed=False)
44+
assert metadata.nx_graph.edges() == expected_graph.edges()
45+
assert metadata.nx_graph.nodes() == expected_graph.nodes()
46+
assert metadata.gate_durations is None
47+
48+
49+
def test_griddevice_metadata_bad_durations():
50+
qubits = tuple(cirq.GridQubit.rect(1, 2))
51+
52+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate)
53+
invalid_duration = {
54+
cirq.Gateset(cirq.XPowGate): cirq.Duration(nanos=1),
55+
cirq.Gateset(cirq.ZPowGate): cirq.Duration(picos=1),
56+
}
57+
with pytest.raises(ValueError, match="ZPowGate"):
58+
cirq.GridDeviceMetadata([qubits], gateset, gate_durations=invalid_duration)
59+
60+
61+
def test_griddevice_json_load():
62+
qubits = cirq.GridQubit.rect(2, 3)
63+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
64+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
65+
duration = {
66+
cirq.Gateset(cirq.XPowGate): cirq.Duration(nanos=1),
67+
cirq.Gateset(cirq.YPowGate): cirq.Duration(picos=2),
68+
cirq.Gateset(cirq.ZPowGate): cirq.Duration(picos=3),
69+
}
70+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
71+
rep_str = cirq.to_json(metadata)
72+
assert metadata == cirq.read_json(json_text=rep_str)
73+
74+
75+
def test_griddevice_metadata_equality():
76+
qubits = cirq.GridQubit.rect(2, 3)
77+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
78+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
79+
duration = {
80+
cirq.Gateset(cirq.XPowGate): cirq.Duration(nanos=1),
81+
cirq.Gateset(cirq.YPowGate): cirq.Duration(picos=3),
82+
cirq.Gateset(cirq.ZPowGate): cirq.Duration(picos=2),
83+
}
84+
duration2 = {
85+
cirq.Gateset(cirq.XPowGate): cirq.Duration(nanos=10),
86+
cirq.Gateset(cirq.YPowGate): cirq.Duration(picos=13),
87+
cirq.Gateset(cirq.ZPowGate): cirq.Duration(picos=12),
88+
}
89+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
90+
metadata2 = cirq.GridDeviceMetadata(qubit_pairs[:2], gateset, gate_durations=duration)
91+
metadata3 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=None)
92+
metadata4 = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration2)
93+
metadata5 = cirq.GridDeviceMetadata(reversed(qubit_pairs), gateset, gate_durations=duration)
94+
95+
eq = cirq.testing.EqualsTester()
96+
eq.add_equality_group(metadata)
97+
eq.add_equality_group(metadata2)
98+
eq.add_equality_group(metadata3)
99+
eq.add_equality_group(metadata4)
100+
101+
assert metadata == metadata5
102+
103+
104+
def test_repr():
105+
qubits = cirq.GridQubit.rect(2, 3)
106+
qubit_pairs = [(a, b) for a in qubits for b in qubits if a != b and a.is_adjacent(b)]
107+
gateset = cirq.Gateset(cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate)
108+
duration = {
109+
cirq.Gateset(cirq.XPowGate): cirq.Duration(nanos=1),
110+
cirq.Gateset(cirq.YPowGate): cirq.Duration(picos=3),
111+
cirq.Gateset(cirq.ZPowGate): cirq.Duration(picos=2),
112+
}
113+
metadata = cirq.GridDeviceMetadata(qubit_pairs, gateset, gate_durations=duration)
114+
cirq.testing.assert_equivalent_repr(metadata)

cirq/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _parallel_gate_op(gate, qubits):
8787
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
8888
'GlobalPhaseGate': cirq.GlobalPhaseGate,
8989
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
90+
'GridDeviceMetadata': cirq.GridDeviceMetadata,
9091
'GridInteractionLayer': GridInteractionLayer,
9192
'GridParallelXEBMetadata': GridParallelXEBMetadata,
9293
'GridQid': cirq.GridQid,

0 commit comments

Comments
 (0)