Skip to content

Commit 04fd49c

Browse files
Update equalstester in DeviceMetadata. (quantumlib#4840)
Part of quantumlib#4743
1 parent 1a9c3d4 commit 04fd49c

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

cirq-core/cirq/devices/device.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def __init__(
207207

208208
self._nx_graph = nx_graph
209209

210+
@property
210211
def qubit_set(self) -> Optional[FrozenSet['cirq.Qid']]:
211212
"""Returns a set of qubits on the device, if possible.
212213
@@ -215,6 +216,7 @@ def qubit_set(self) -> Optional[FrozenSet['cirq.Qid']]:
215216
"""
216217
return self._qubits_set
217218

219+
@property
218220
def nx_graph(self) -> Optional['nx.Graph']:
219221
"""Returns a nx.Graph where nodes are qubits and edges are couple-able qubits.
220222
@@ -226,13 +228,12 @@ def nx_graph(self) -> Optional['nx.Graph']:
226228
def _value_equality_values_(self):
227229
graph_equality = None
228230
if self._nx_graph is not None:
229-
graph_equality = (sorted(self._nx_graph.nodes()), sorted(self._nx_graph.edges()))
230-
231-
qubit_equality = None
232-
if self._qubits_set is not None:
233-
qubit_equality = sorted(list(self._qubits_set))
231+
graph_equality = (
232+
tuple(sorted(self._nx_graph.nodes())),
233+
tuple(sorted(self._nx_graph.edges(data='directed'))),
234+
)
234235

235-
return qubit_equality, graph_equality
236+
return self._qubits_set, graph_equality
236237

237238
def _json_dict_(self):
238239
graph_payload = ''

cirq-core/cirq/devices/device_test.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ def test_metadata():
8282
qubits = cirq.LineQubit.range(4)
8383
graph = nx.star_graph(3)
8484
metadata = cirq.DeviceMetadata(qubits, graph)
85-
assert metadata.qubit_set() == frozenset(qubits)
86-
assert metadata.nx_graph() == graph
85+
assert metadata.qubit_set == frozenset(qubits)
86+
assert metadata.nx_graph == graph
8787

8888
metadata = cirq.DeviceMetadata()
89-
assert metadata.qubit_set() is None
90-
assert metadata.nx_graph() is None
89+
assert metadata.qubit_set is None
90+
assert metadata.nx_graph is None
9191

9292

9393
def test_metadata_json_load_logic():
@@ -103,3 +103,19 @@ def test_metadata_json_load_logic():
103103
str_rep = cirq.to_json(metadata)
104104
output = cirq.read_json(json_text=str_rep)
105105
assert metadata == output
106+
107+
108+
def test_metadata_equality():
109+
qubits = cirq.LineQubit.range(4)
110+
graph = nx.star_graph(3)
111+
graph2 = nx.star_graph(3)
112+
graph.add_edge(1, 2, directed=False)
113+
graph2.add_edge(1, 2, directed=True)
114+
115+
eq = cirq.testing.EqualsTester()
116+
eq.add_equality_group(cirq.DeviceMetadata(qubits, graph))
117+
eq.add_equality_group(cirq.DeviceMetadata(None, graph))
118+
eq.add_equality_group(cirq.DeviceMetadata(qubits, None))
119+
eq.add_equality_group(cirq.DeviceMetadata(None, None))
120+
121+
assert cirq.DeviceMetadata(None, graph) != cirq.DeviceMetadata(None, graph2)

0 commit comments

Comments
 (0)