Skip to content

Commit 6a63eb8

Browse files
MichaelBroughtonrht
authored andcommitted
Add DeviceMetaData class. (quantumlib#4832)
Adds standalone DeviceMetaData class. First step in quantumlib#4743 .
1 parent e935432 commit 6a63eb8

File tree

8 files changed

+164
-8
lines changed

8 files changed

+164
-8
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
from cirq.devices import (
8282
ConstantQubitNoiseModel,
8383
Device,
84+
DeviceMetadata,
8485
GridQid,
8586
GridQubit,
8687
LineQid,

cirq-core/cirq/devices/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Types for devices, device-specific qubits, and noise models."""
1616
from cirq.devices.device import (
1717
Device,
18+
DeviceMetadata,
1819
SymmetricalQidPair,
1920
)
2021

cirq-core/cirq/devices/device.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414

1515
import abc
16-
from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator
16+
from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator, Iterable
1717

18+
import networkx as nx
1819
from cirq import value
1920
from cirq.devices.grid_qubit import _BaseGridQid
2021
from cirq.devices.line_qubit import _BaseLineQid
@@ -178,3 +179,77 @@ def __iter__(self) -> Iterator['cirq.Qid']:
178179

179180
def __contains__(self, item: 'cirq.Qid') -> bool:
180181
return item in self.qids
182+
183+
184+
@value.value_equality
185+
class DeviceMetadata:
186+
"""Parent type for all device specific metadata classes."""
187+
188+
def __init__(
189+
self,
190+
qubits: Optional[Iterable['cirq.Qid']] = None,
191+
nx_graph: Optional['nx.graph'] = None,
192+
):
193+
"""Construct a DeviceMetadata object.
194+
195+
Args:
196+
qubits: Optional iterable of `cirq.Qid`s that exist on the device.
197+
nx_graph: Optional `nx.Graph` describing qubit connectivity
198+
on a device. Nodes represent qubits, directed edges indicate
199+
directional coupling, undirected edges indicate bi-directional
200+
coupling.
201+
"""
202+
if qubits is not None:
203+
qubits = frozenset(qubits)
204+
self._qubits_set: Optional[FrozenSet['cirq.Qid']] = (
205+
None if qubits is None else frozenset(qubits)
206+
)
207+
208+
self._nx_graph = nx_graph
209+
210+
def qubit_set(self) -> Optional[FrozenSet['cirq.Qid']]:
211+
"""Returns a set of qubits on the device, if possible.
212+
213+
Returns:
214+
Frozenset of qubits on device if specified, otherwise None.
215+
"""
216+
return self._qubits_set
217+
218+
def nx_graph(self) -> Optional['nx.Graph']:
219+
"""Returns a nx.Graph where nodes are qubits and edges are couple-able qubits.
220+
221+
Returns:
222+
`nx.Graph` of device connectivity if specified, otherwise None.
223+
"""
224+
return self._nx_graph
225+
226+
def _value_equality_values_(self):
227+
graph_equality = None
228+
if self._nx_graph is not None:
229+
graph_equality = (sorted(self._nx_graph.nodes()), sorted(self._nx_graph.edges()))
230+
231+
qubit_equality = None
232+
if self._qubits_set is not None:
233+
qubit_equality = sorted(list(self._qubits_set))
234+
235+
return qubit_equality, graph_equality
236+
237+
def _json_dict_(self):
238+
graph_payload = ''
239+
if self._nx_graph is not None:
240+
graph_payload = nx.readwrite.json_graph.node_link_data(self._nx_graph)
241+
242+
qubits_payload = ''
243+
if self._qubits_set is not None:
244+
qubits_payload = sorted(list(self._qubits_set))
245+
246+
return {'qubits': qubits_payload, 'nx_graph': graph_payload}
247+
248+
@classmethod
249+
def _from_json_dict_(cls, qubits, nx_graph, **kwargs):
250+
if qubits == '':
251+
qubits = None
252+
graph_obj = None
253+
if nx_graph != '':
254+
graph_obj = nx.readwrite.json_graph.node_link_graph(nx_graph)
255+
return cls(qubits, graph_obj)

cirq-core/cirq/devices/device_test.py

+28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=wrong-or-nonexistent-copyright-notice
22
import pytest
3+
import networkx as nx
34
import cirq
45

56

@@ -75,3 +76,30 @@ def test_qid_pair():
7576

7677
with pytest.raises(ValueError, match='A QidPair cannot have identical qids.'):
7778
cirq.SymmetricalQidPair(q0, q0)
79+
80+
81+
def test_metadata():
82+
qubits = cirq.LineQubit.range(4)
83+
graph = nx.star_graph(3)
84+
metadata = cirq.DeviceMetadata(qubits, graph)
85+
assert metadata.qubit_set() == frozenset(qubits)
86+
assert metadata.nx_graph() == graph
87+
88+
metadata = cirq.DeviceMetadata()
89+
assert metadata.qubit_set() is None
90+
assert metadata.nx_graph() is None
91+
92+
93+
def test_metadata_json_load_logic():
94+
qubits = cirq.LineQubit.range(4)
95+
graph = nx.star_graph(3)
96+
metadata = cirq.DeviceMetadata(qubits, graph)
97+
str_rep = cirq.to_json(metadata)
98+
assert metadata == cirq.read_json(json_text=str_rep)
99+
100+
qubits = None
101+
graph = None
102+
metadata = cirq.DeviceMetadata(qubits, graph)
103+
str_rep = cirq.to_json(metadata)
104+
output = cirq.read_json(json_text=str_rep)
105+
assert metadata == output

cirq-core/cirq/json_resolver_cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _parallel_gate_op(gate, qubits):
7777
'CZPowGate': cirq.CZPowGate,
7878
'DensePauliString': cirq.DensePauliString,
7979
'DepolarizingChannel': cirq.DepolarizingChannel,
80+
'DeviceMetadata': cirq.DeviceMetadata,
8081
'Duration': cirq.Duration,
8182
'FrozenCircuit': cirq.FrozenCircuit,
8283
'FSimGate': cirq.FSimGate,

cirq-core/cirq/protocols/json_serialization_test.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import ClassVar, Dict, List, Optional, Tuple, Type
2525
from unittest import mock
2626

27+
import networkx as nx
2728
import numpy as np
2829
import pandas as pd
2930
import pytest
@@ -726,13 +727,7 @@ def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str]
726727
if deprecation is not None and deprecation.old_name in content:
727728
ctx_managers.append(deprecation.deprecation_assertion)
728729

729-
imports = {
730-
'cirq': cirq,
731-
'pd': pd,
732-
'sympy': sympy,
733-
'np': np,
734-
'datetime': datetime,
735-
}
730+
imports = {'cirq': cirq, 'pd': pd, 'sympy': sympy, 'np': np, 'datetime': datetime, 'nx': nx}
736731

737732
for m in TESTED_MODULES.keys():
738733
try:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"cirq_type": "DeviceMetadata",
3+
"qubits": [
4+
{
5+
"cirq_type": "LineQubit",
6+
"x": 0
7+
},
8+
{
9+
"cirq_type": "LineQubit",
10+
"x": 1
11+
},
12+
{
13+
"cirq_type": "LineQubit",
14+
"x": 2
15+
},
16+
{
17+
"cirq_type": "LineQubit",
18+
"x": 3
19+
}
20+
],
21+
"nx_graph": {
22+
"directed": false,
23+
"multigraph": false,
24+
"graph": {},
25+
"nodes": [
26+
{
27+
"id": 0
28+
},
29+
{
30+
"id": 1
31+
},
32+
{
33+
"id": 2
34+
},
35+
{
36+
"id": 3
37+
}
38+
],
39+
"links": [
40+
{
41+
"source": 0,
42+
"target": 1
43+
},
44+
{
45+
"source": 0,
46+
"target": 2
47+
},
48+
{
49+
"source": 0,
50+
"target": 3
51+
}
52+
]
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.DeviceMetadata(cirq.LineQubit.range(4), nx.star_graph(3))

0 commit comments

Comments
 (0)