Skip to content

Commit df9c9e8

Browse files
mpharriganrht
authored andcommitted

12 files changed

+748
-1
lines changed

cirq-core/cirq/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@
9191
NoiseModel,
9292
SymmetricalQidPair,
9393
UNCONSTRAINED_DEVICE,
94+
NamedTopology,
95+
draw_gridlike,
96+
LineTopology,
97+
TiltedSquareLattice,
98+
get_placements,
99+
draw_placements,
94100
)
95101

96102
from cirq.experiments import (

cirq-core/cirq/devices/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@
3838
NoiseModel,
3939
ConstantQubitNoiseModel,
4040
)
41+
42+
from cirq.devices.named_topologies import (
43+
NamedTopology,
44+
draw_gridlike,
45+
LineTopology,
46+
TiltedSquareLattice,
47+
get_placements,
48+
draw_placements,
49+
)
+316
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# Copyright 2021 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
import dataclasses
17+
import warnings
18+
from dataclasses import dataclass
19+
from typing import Dict, List, Tuple, Any, Sequence, Union, Iterable, TYPE_CHECKING
20+
21+
import networkx as nx
22+
from cirq.devices import GridQubit
23+
from cirq.protocols.json_serialization import obj_to_dict_helper
24+
from matplotlib import pyplot as plt
25+
26+
if TYPE_CHECKING:
27+
import cirq
28+
29+
30+
def dataclass_json_dict(obj: Any, namespace: str = None) -> Dict[str, Any]:
31+
return obj_to_dict_helper(obj, [f.name for f in dataclasses.fields(obj)], namespace=namespace)
32+
33+
34+
class NamedTopology(metaclass=abc.ABCMeta):
35+
"""A topology (graph) with a name.
36+
37+
"Named topologies" provide a mapping from a simple dataclass to a unique graph for categories
38+
of relevant topologies. Relevant topologies may be hardware dependant, but common topologies
39+
are linear (1D) and rectangular grid topologies.
40+
"""
41+
42+
name: str = NotImplemented
43+
"""A name that uniquely identifies this topology."""
44+
45+
n_nodes: int = NotImplemented
46+
"""The number of nodes in the topology."""
47+
48+
graph: nx.Graph = NotImplemented
49+
"""A networkx graph representation of the topology."""
50+
51+
52+
_GRIDLIKE_NODE = Union['cirq.GridQubit', Tuple[int, int]]
53+
54+
55+
def _node_and_coordinates(
56+
nodes: Iterable[_GRIDLIKE_NODE],
57+
) -> Iterable[Tuple[_GRIDLIKE_NODE, Tuple[int, int]]]:
58+
"""Yield tuples whose first element is the input node and the second is guaranteed to be a tuple
59+
of two integers. The input node can be a tuple of ints or a GridQubit."""
60+
for node in nodes:
61+
if isinstance(node, GridQubit):
62+
yield node, (node.row, node.col)
63+
else:
64+
x, y = node
65+
yield node, (x, y)
66+
67+
68+
def draw_gridlike(
69+
graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs
70+
) -> Dict[Any, Tuple[int, int]]:
71+
"""Draw a grid-like graph using Matplotlib.
72+
73+
This wraps nx.draw_networkx to produce a matplotlib drawing of the graph. Nodes
74+
should be two-dimensional gridlike objects.
75+
76+
Args:
77+
graph: A NetworkX graph whose nodes are (row, column) coordinates or cirq.GridQubits.
78+
ax: Optional matplotlib axis to use for drawing.
79+
tilted: If True, directly position as (row, column); otherwise,
80+
rotate 45 degrees to accommodate google-style diagonal grids.
81+
kwargs: Additional arguments to pass to `nx.draw_networkx`.
82+
83+
Returns:
84+
A positions dictionary mapping nodes to (x, y) coordinates suitable for future calls
85+
to NetworkX plotting functionality.
86+
"""
87+
if ax is None:
88+
ax = plt.gca() # coverage: ignore
89+
90+
if tilted:
91+
pos = {node: (y, -x) for node, (x, y) in _node_and_coordinates(graph.nodes)}
92+
else:
93+
pos = {node: (x + y, y - x) for node, (x, y) in _node_and_coordinates(graph.nodes)}
94+
95+
nx.draw_networkx(graph, pos=pos, ax=ax, **kwargs)
96+
ax.axis('equal')
97+
return pos
98+
99+
100+
@dataclass(frozen=True)
101+
class LineTopology(NamedTopology):
102+
"""A 1D linear topology.
103+
104+
Node indices are contiguous integers starting from 0 with edges between
105+
adjacent integers.
106+
107+
Args:
108+
n_nodes: The number of nodes in a line.
109+
"""
110+
111+
n_nodes: int
112+
113+
def __post_init__(self):
114+
if self.n_nodes <= 1:
115+
raise ValueError("`n_nodes` must be greater than 1.")
116+
object.__setattr__(self, 'name', f'line-{self.n_nodes}')
117+
graph = nx.from_edgelist(
118+
[(i1, i2) for i1, i2 in zip(range(self.n_nodes), range(1, self.n_nodes))]
119+
)
120+
object.__setattr__(self, 'graph', graph)
121+
122+
def draw(self, ax=None, tilted: bool = True, **kwargs) -> Dict[Any, Tuple[int, int]]:
123+
"""Draw this graph using Matplotlib.
124+
125+
Args:
126+
ax: Optional matplotlib axis to use for drawing.
127+
tilted: If True, draw as a horizontal line. Otherwise, draw on a diagonal.
128+
kwargs: Additional arguments to pass to `nx.draw_networkx`.
129+
"""
130+
g2 = nx.relabel_nodes(self.graph, {n: (n, 1) for n in self.graph.nodes})
131+
return draw_gridlike(g2, ax=ax, tilted=tilted, **kwargs)
132+
133+
def _json_dict_(self) -> Dict[str, Any]:
134+
return dataclass_json_dict(self)
135+
136+
137+
@dataclass(frozen=True)
138+
class TiltedSquareLattice(NamedTopology):
139+
"""A grid lattice rotated 45-degrees.
140+
141+
This topology is based on Google devices where plaquettes consist of four qubits in a square
142+
connected to a central qubit:
143+
144+
x x
145+
x
146+
x x
147+
148+
The corner nodes are not connected to each other. `width` and `height` refer to the rectangle
149+
formed by rotating the lattice 45 degrees. `width` and `height` are measured in half-unit
150+
cells, or equivalently half the number of central nodes.
151+
An example diagram of this topology is shown below. It is a
152+
"tilted-square-lattice-6-4" with width 6 and height 4.
153+
154+
x
155+
156+
x────X────x
157+
│ │ │
158+
x────X────x────X────x
159+
│ │ │ │
160+
x────X────x────X───x
161+
│ │ │
162+
x────X────x
163+
164+
x
165+
166+
Nodes are 2-tuples of integers which may be negative. Please see `get_placements` for
167+
mapping this topology to a GridQubit Device.
168+
"""
169+
170+
width: int
171+
height: int
172+
173+
def __post_init__(self):
174+
if self.width <= 0:
175+
raise ValueError("Width must be a positive integer")
176+
if self.height <= 0:
177+
raise ValueError("Height must be a positive integer")
178+
179+
object.__setattr__(self, 'name', f'tilted-square-lattice-{self.width}-{self.height}')
180+
181+
rect1 = set(
182+
(i + j, i - j) for i in range(self.width // 2 + 1) for j in range(self.height // 2 + 1)
183+
)
184+
rect2 = set(
185+
((i + j) // 2, (i - j) // 2)
186+
for i in range(1, self.width + 1, 2)
187+
for j in range(1, self.height + 1, 2)
188+
)
189+
nodes = rect1 | rect2
190+
g = nx.Graph()
191+
for node in nodes:
192+
for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
193+
neighbor = (node[0] + dx, node[1] + dy)
194+
if neighbor in nodes:
195+
g.add_edge(node, neighbor)
196+
197+
object.__setattr__(self, 'graph', g)
198+
199+
# The number of edges = width * height (see unit tests). This can be seen if you remove
200+
# all vertices and replace edges with dots.
201+
# The formula for the number of vertices is not that nice, but you can derive it by
202+
# summing big and small Xes in the asciiart in the docstring.
203+
# There are (width//2 + 1) * (height//2 + 1) small xes and
204+
# ((width + 1)//2) * ((height + 1)//2) big ones.
205+
n_nodes = (self.width // 2 + 1) * (self.height // 2 + 1)
206+
n_nodes += ((self.width + 1) // 2) * ((self.height + 1) // 2)
207+
object.__setattr__(self, 'n_nodes', n_nodes)
208+
209+
def draw(self, ax=None, tilted=True, **kwargs):
210+
"""Draw this graph using Matplotlib.
211+
212+
Args:
213+
ax: Optional matplotlib axis to use for drawing.
214+
tilted: If True, directly position as (row, column); otherwise,
215+
rotate 45 degrees to accommodate the diagonal nature of this topology.
216+
kwargs: Additional arguments to pass to `nx.draw_networkx`.
217+
"""
218+
return draw_gridlike(self.graph, ax=ax, tilted=tilted, **kwargs)
219+
220+
def nodes_as_gridqubits(self) -> List['cirq.GridQubit']:
221+
"""Get the graph nodes as cirq.GridQubit"""
222+
return [GridQubit(r, c) for r, c in sorted(self.graph.nodes)]
223+
224+
def _json_dict_(self) -> Dict[str, Any]:
225+
return dataclass_json_dict(self)
226+
227+
228+
def get_placements(
229+
big_graph: nx.Graph, small_graph: nx.Graph, max_placements=100_000
230+
) -> List[Dict]:
231+
"""Get 'placements' mapping small_graph nodes onto those of `big_graph`.
232+
233+
This function considers monomorphisms with a restriction: we restrict only to unique set
234+
of `big_graph` qubits. Some monomorphisms may be basically
235+
the same mapping just rotated/flipped which we purposefully exclude. This could
236+
exclude meaningful differences like using the same qubits but having the edges assigned
237+
differently, but it prevents the number of placements from blowing up.
238+
239+
Args:
240+
big_graph: The parent, super-graph. We often consider the case where this is a
241+
nx.Graph representation of a Device whose nodes are `cirq.Qid`s like `GridQubit`s.
242+
small_graph: The subgraph. We often consider the case where this is a NamedTopology
243+
graph.
244+
max_placements: Raise a value error if there are more than this many placement
245+
possibilities. It is possible to use `big_graph`, `small_graph` combinations
246+
that result in an intractable number of placements.
247+
248+
Raises:
249+
ValueError: if the number of placements exceeds `max_placements`.
250+
251+
Returns:
252+
A list of placement dictionaries. Each dictionary maps the nodes in `small_graph` to
253+
nodes in `big_graph` with a monomorphic relationship. That's to say: if an edge exists
254+
in `small_graph` between two nodes, it will exist in `big_graph` between the mapped nodes.
255+
"""
256+
matcher = nx.algorithms.isomorphism.GraphMatcher(big_graph, small_graph)
257+
258+
# de-duplicate rotations, see docstring.
259+
dedupe = {}
260+
for big_to_small_map in matcher.subgraph_monomorphisms_iter():
261+
dedupe[frozenset(big_to_small_map.keys())] = big_to_small_map
262+
if len(dedupe) > max_placements:
263+
# coverage: ignore
264+
raise ValueError(
265+
f"We found more than {max_placements} placements. Please use a "
266+
f"more constraining `big_graph` or a more constrained `small_graph`."
267+
)
268+
269+
small_to_bigs = []
270+
for big in sorted(dedupe.keys()):
271+
big_to_small_map = dedupe[big]
272+
small_to_big_map = {v: k for k, v in big_to_small_map.items()}
273+
small_to_bigs.append(small_to_big_map)
274+
return small_to_bigs
275+
276+
277+
def draw_placements(
278+
big_graph: nx.Graph,
279+
small_graph: nx.Graph,
280+
small_to_big_mappings,
281+
max_plots=20,
282+
axes: Sequence[plt.Axes] = None,
283+
):
284+
"""Draw a visualization of placements from small_graph onto big_graph using Matplotlib.
285+
286+
The entire `big_graph` will be drawn with default blue colored nodes. `small_graph` nodes
287+
and edges will be highlighted with a red color.
288+
"""
289+
if len(small_to_big_mappings) > max_plots:
290+
# coverage: ignore
291+
warnings.warn(f"You've provided a lot of mappings. Only plotting the first {max_plots}")
292+
small_to_big_mappings = small_to_big_mappings[:max_plots]
293+
294+
call_show = False
295+
if axes is None:
296+
# coverage: ignore
297+
call_show = True
298+
299+
for i, small_to_big_map in enumerate(small_to_big_mappings):
300+
if axes is not None:
301+
ax = axes[i]
302+
else:
303+
# coverage: ignore
304+
ax = plt.gca()
305+
306+
small_mapped = nx.relabel_nodes(small_graph, small_to_big_map)
307+
draw_gridlike(big_graph, ax=ax)
308+
draw_gridlike(
309+
small_mapped, node_color='red', edge_color='red', width=2, with_labels=False, ax=ax
310+
)
311+
ax.axis('equal')
312+
if call_show:
313+
# coverage: ignore
314+
# poor man's multi-axis figure: call plt.show() after each plot
315+
# and jupyter will put the plots one after another.
316+
plt.show()

0 commit comments

Comments
 (0)