-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Named Topologies #4370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Named Topologies #4370
Changes from 25 commits
b7b6fd4
538c823
50a9725
6ca0c0a
a2e5566
3e32e11
e1bd4d5
0f00f85
a17ce60
d8f12f4
b77f7ea
1ae35ba
866720d
e78d37c
9618d1f
9abaa4f
001062a
16016df
fd275c3
4bc9804
fff203c
bc428dd
88332e4
502b0aa
d5360f9
2f7ef60
43bf048
c89ef0e
ba2c6a5
af447a5
1cdb121
144e321
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,337 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import abc | ||
import dataclasses | ||
import warnings | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Tuple, Any, Sequence, Union, Iterable, TYPE_CHECKING | ||
|
||
import networkx as nx | ||
from cirq.devices import GridQubit | ||
from cirq.protocols.json_serialization import obj_to_dict_helper | ||
from matplotlib import pyplot as plt | ||
|
||
if TYPE_CHECKING: | ||
import cirq | ||
|
||
|
||
def dataclass_json_dict(obj: Any, namespace: str = None) -> Dict[str, Any]: | ||
return obj_to_dict_helper(obj, [f.name for f in dataclasses.fields(obj)], namespace=namespace) | ||
|
||
|
||
class NamedTopology(metaclass=abc.ABCMeta): | ||
"""A topology (graph) with a name. | ||
|
||
"Named topologies" provide a mapping from a simple dataclass to a unique graph for categories | ||
of relevant topologies. Relevant topologies may be hardware dependant, but common topologies | ||
are linear (1D) and rectangular grid topologies. | ||
""" | ||
|
||
name: str = NotImplemented | ||
"""A name that uniquely identifies this topology.""" | ||
|
||
n_nodes: int = NotImplemented | ||
"""The number of nodes in the topology.""" | ||
|
||
graph: nx.Graph = NotImplemented | ||
"""A networkx graph representation of the topology.""" | ||
|
||
|
||
_GRIDLIKE_NODE = Union['cirq.GridQubit', Tuple[int, int]] | ||
|
||
|
||
def _node_and_coordinates( | ||
nodes: Iterable[_GRIDLIKE_NODE], | ||
) -> Iterable[Tuple[_GRIDLIKE_NODE, Tuple[int, int]]]: | ||
"""Yield tuples whose first element is the input node and the second is guaranteed to be a tuple | ||
of two integers. The input node can be a tuple of ints or a GridQubit.""" | ||
for node in nodes: | ||
if isinstance(node, GridQubit): | ||
yield node, (node.row, node.col) | ||
else: | ||
x, y = node | ||
yield node, (x, y) | ||
|
||
|
||
def draw_gridlike( | ||
graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs | ||
) -> Dict[Any, Tuple[int, int]]: | ||
"""Draw a Grid-like graph. | ||
|
||
This wraps nx.draw_networkx to produce a matplotlib drawing of the graph. | ||
|
||
Args: | ||
graph: A NetworkX graph whose nodes are (row, column) coordinates. | ||
ax: Optional matplotlib axis to use for drawing. | ||
tilted: If True, directly position as (row, column); otherwise, | ||
rotate 45 degrees to accommodate google-style diagonal grids. | ||
kwargs: Additional arguments to pass to `nx.draw_networkx`. | ||
|
||
Returns: | ||
A positions dictionary mapping nodes to (x, y) coordinates suitable for future calls | ||
to NetworkX plotting functionality. | ||
""" | ||
if ax is None: | ||
ax = plt.gca() # coverage: ignore | ||
|
||
if tilted: | ||
pos = {node: (y, -x) for node, (x, y) in _node_and_coordinates(graph.nodes)} | ||
else: | ||
pos = {node: (x + y, y - x) for node, (x, y) in _node_and_coordinates(graph.nodes)} | ||
|
||
nx.draw_networkx(graph, pos=pos, ax=ax, **kwargs) | ||
ax.axis('equal') | ||
return pos | ||
|
||
|
||
@dataclass(frozen=True) | ||
class LineTopology(NamedTopology): | ||
"""A 1D linear topology. | ||
|
||
Node indices are contiguous integers starting from 0 with edges between | ||
adjacent integers. | ||
|
||
Args: | ||
n_nodes: The number of nodes in a line. | ||
""" | ||
|
||
n_nodes: int | ||
|
||
def __post_init__(self): | ||
if self.n_nodes <= 1: | ||
raise ValueError("`n_nodes` must be greater than 1.") | ||
object.__setattr__(self, 'name', f'line-{self.n_nodes}') | ||
graph = nx.from_edgelist( | ||
[(i1, i2) for i1, i2 in zip(range(self.n_nodes), range(1, self.n_nodes))] | ||
) | ||
object.__setattr__(self, 'graph', graph) | ||
|
||
def draw(self, ax=None, tilted: bool = True, **kwargs) -> Dict[Any, Tuple[int, int]]: | ||
"""Draw this graph. | ||
|
||
Args: | ||
ax: Optional matplotlib axis to use for drawing. | ||
tilted: If True, draw as a horizontal line. Otherwise, draw on a diagonal. | ||
kwargs: Additional arguments to pass to `nx.draw_networkx`. | ||
""" | ||
g2 = nx.relabel_nodes(self.graph, {n: (n, 1) for n in self.graph.nodes}) | ||
return draw_gridlike(g2, ax=ax, tilted=tilted, **kwargs) | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
return dataclass_json_dict(self) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class TiltedSquareLattice(NamedTopology): | ||
"""A grid lattice rotated 45-degrees. | ||
|
||
This topology is based on Google devices where plaquettes consist of four qubits in a square | ||
connected to a central qubit: | ||
|
||
x x | ||
x | ||
x x | ||
|
||
The corner nodes are not connected to each other. `width` and `height` refer to the rectangle | ||
formed by rotating the lattice 45 degrees. `width` and `height` are measured in half-unit | ||
cells, or equivalently half the number of central nodes. | ||
An example diagram of this topology is shown below. It is a | ||
"tilted-square-lattice-6-4" with width 6 and height 4. | ||
|
||
x | ||
│ | ||
x────X────x | ||
│ │ │ | ||
x────X────x────X────x | ||
│ │ │ │ | ||
x────X────x────X───x | ||
│ │ │ | ||
x────X────x | ||
│ | ||
x | ||
|
||
Nodes are 2-tuples of integers which may be negative. Please see `get_placements` for | ||
mapping this topology to a GridQubit Device. | ||
""" | ||
|
||
width: int | ||
height: int | ||
|
||
def __post_init__(self): | ||
if self.width <= 0: | ||
raise ValueError("Width must be a positive integer") | ||
if self.height <= 0: | ||
raise ValueError("Height must be a positive integer") | ||
|
||
object.__setattr__(self, 'name', f'tilted-square-lattice-{self.width}-{self.height}') | ||
|
||
g = nx.Graph() | ||
|
||
def _add_edge(unit_row: int, unit_col: int, *, which: int): | ||
"""Helper function to add edges in 'unit cell coordinates'.""" | ||
y = unit_col + unit_row | ||
x = unit_col - unit_row | ||
|
||
if which == 0: | ||
# Either in the bulk or on a ragged boundary, we need this edge | ||
g.add_edge((x, y), (x, y - 1)) | ||
elif which == 1: | ||
# This is added in the bulk and for a "top" (extra height) ragged boundary | ||
g.add_edge((x, y), (x + 1, y)) | ||
elif which == 2: | ||
# This is added in the bulk and for a "side" (extra width) ragged boundary | ||
g.add_edge((x, y), (x - 1, y)) | ||
elif which == 3: | ||
# This is only added in the bulk. | ||
g.add_edge((x, y), (x, y + 1)) | ||
else: | ||
raise ValueError() # coverage: ignore | ||
|
||
# Iterate over unit cells, which are in units of 2*width, 2*height. | ||
# Add all all four edges when we're in the bulk. | ||
unit_cell_height = self.height // 2 | ||
unit_cell_width = self.width // 2 | ||
for unit_row in range(unit_cell_height): | ||
for unit_col in range(unit_cell_width): | ||
for i in range(4): | ||
_add_edge(unit_row, unit_col, which=i) | ||
|
||
extra_h = self.height % 2 | ||
if extra_h: | ||
# Add extra height to the final half-row. | ||
for unit_col in range(unit_cell_width): | ||
_add_edge(unit_cell_height, unit_col, which=0) | ||
_add_edge(unit_cell_height, unit_col, which=1) | ||
|
||
extra_w = self.width % 2 | ||
if extra_w: | ||
# Add extra width to the final half-column | ||
for unit_row in range(unit_cell_height): | ||
_add_edge(unit_row, unit_cell_width, which=0) | ||
_add_edge(unit_row, unit_cell_width, which=2) | ||
|
||
if extra_w and extra_h: | ||
# Add the final corner node when we have both ragged boundaries | ||
_add_edge(unit_cell_height, unit_cell_width, which=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we might be able to tighten this up and just do something like: rect1 = set((i + j, i - j) for i in range(w//2 + 1) for j in range(h//2 + 1))
rect2 = set(((i + j) // 2, (i - j) // 2) for i in range(1, w + 1, 2) for j in range(1, h + 1, 2))
all_nodes = rect1 | rect2
g = nx.Graph()
for u in all_nodes:
for dx, dy in [(1,0), (-1, 0), (0,1), (0,-1)]:
v = (u[0] + dx, u[1] + dy)
if v in all_nodes:
g.add_edge(u, v) here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
object.__setattr__(self, 'graph', g) | ||
|
||
# The number of edges = width * height (see unit tests). This can be seen if you remove | ||
# all vertices and replace edges with dots. | ||
# The formula for the number of vertices is not that nice, but you can derive it by | ||
# summing big and small Xes in the asciiart in the docstring. | ||
# There are (width//2 + 1) * (height//2 + 1) small xes and | ||
# ((width + 1)//2) * ((height + 1)//2) big ones. | ||
n_nodes = (self.width // 2 + 1) * (self.height // 2 + 1) | ||
n_nodes += ((self.width + 1) // 2) * ((self.height + 1) // 2) | ||
object.__setattr__(self, 'n_nodes', n_nodes) | ||
Comment on lines
+205
to
+207
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could. I wanted a formula for the number of nodes in terms of width and height; the unit test verifies that the two methods agree. The potential benefit of a contributor being able to read off a formula by looking at the code offsets the downside of not using |
||
|
||
def draw(self, ax=None, tilted=True, **kwargs): | ||
"""Draw this graph | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: same comment here, plus missing period. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
Args: | ||
ax: Optional matplotlib axis to use for drawing. | ||
tilted: If True, directly position as (row, column); otherwise, | ||
rotate 45 degrees to accommodate the diagonal nature of this topology. | ||
kwargs: Additional arguments to pass to `nx.draw_networkx`. | ||
""" | ||
return draw_gridlike(self.graph, ax=ax, tilted=tilted, **kwargs) | ||
|
||
def nodes_as_gridqubits(self) -> List['cirq.GridQubit']: | ||
"""Get the graph nodes as cirq.GridQubit""" | ||
return [GridQubit(r, c) for r, c in sorted(self.graph.nodes)] | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
return dataclass_json_dict(self) | ||
|
||
|
||
def get_placements( | ||
big_graph: nx.Graph, small_graph: nx.Graph, max_placements=100_000 | ||
) -> List[Dict]: | ||
"""Get 'placements' mapping small_graph nodes onto those of `big_graph`. | ||
|
||
We often consider the case where `big_graph` is a nx.Graph representation of a Device | ||
whose nodes are `cirq.Qid`s like `GridQubit`s and `small_graph` is a NamedTopology graph. | ||
In this case, this function returns a list of placement dictionaries. Each dictionary | ||
maps the nodes in `small_graph` to nodes in `big_graph` with a monomorphic relationship. | ||
That's to say: if an edge exists in `small_graph` between two nodes, it will exist in | ||
`big_graph` between the mapped nodes. | ||
|
||
We restrict only to unique set of `big_graph` qubits. Some monomorphisms may be basically | ||
the same mapping just rotated/flipped which we purposefully exclude. This could | ||
exclude meaningful differences like using the same qubits but having the edges assigned | ||
differently, but it prevents the number of placements from blowing up. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we tighthen up this docstring a bit ? it's looking a little on the longer side and doesn't list the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactored into Args, Returns. Tightened up some of the wording. Kept most of the content. I think it's all pretty important. This is in a class of functions that doesn't correspond to the super pure general everything function: instead, it encodes general ideas I usually need when doing the particular form of graph matching for putting a circuit on a device. It might not be for everyone. The pure, general version would just be networkx's |
||
matcher = nx.algorithms.isomorphism.GraphMatcher(big_graph, small_graph) | ||
|
||
# de-duplicate rotations, see docstring. | ||
dedupe = {} | ||
for big_to_small_map in matcher.subgraph_monomorphisms_iter(): | ||
dedupe[frozenset(big_to_small_map.keys())] = big_to_small_map | ||
if len(dedupe) > max_placements: | ||
# coverage: ignore | ||
raise ValueError( | ||
f"We found more than {max_placements} placements. Please use a " | ||
f"more constraining `big_graph` or a more constrained `small_graph`." | ||
) | ||
|
||
small_to_bigs = [] | ||
for big in sorted(dedupe.keys()): | ||
big_to_small_map = dedupe[big] | ||
small_to_big_map = {v: k for k, v in big_to_small_map.items()} | ||
small_to_bigs.append(small_to_big_map) | ||
return small_to_bigs | ||
|
||
|
||
def draw_placements( | ||
big_graph: nx.Graph, | ||
small_graph: nx.Graph, | ||
small_to_big_mappings, | ||
max_plots=20, | ||
axes: Sequence[plt.Axes] = None, | ||
): | ||
"""Draw a visualization of placements from small_graph onto big_graph. | ||
|
||
The entire `big_graph` will be drawn with default blue colored nodes. `small_graph` nodes | ||
and edges will be highlighted with a red color. | ||
""" | ||
if len(small_to_big_mappings) > max_plots: | ||
# coverage: ignore | ||
warnings.warn(f"You've provided a lot of mappings. Only plotting the first {max_plots}") | ||
small_to_big_mappings = small_to_big_mappings[:max_plots] | ||
|
||
call_show = False | ||
if axes is None: | ||
# coverage: ignore | ||
call_show = True | ||
|
||
for i, small_to_big_map in enumerate(small_to_big_mappings): | ||
if axes is not None: | ||
ax = axes[i] | ||
else: | ||
# coverage: ignore | ||
ax = plt.gca() | ||
|
||
small_mapped = nx.relabel_nodes(small_graph, small_to_big_map) | ||
draw_gridlike(big_graph, ax=ax) | ||
draw_gridlike( | ||
small_mapped, node_color='red', edge_color='red', width=2, with_labels=False, ax=ax | ||
) | ||
ax.axis('equal') | ||
if call_show: | ||
# coverage: ignore | ||
# poor man's multi-axis figure: call plt.show() after each plot | ||
# and jupyter will put the plots one after another. | ||
plt.show() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a small bit of detail here, such as
Draw this graph using matplotlib.
so that the docstring is not redundant with the function name.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added matplotlib notes here and everywhere