Skip to content

Commit b630298

Browse files
authored
Fix matplotlib typing (#6290)
* Fix matplotlib typing matplotlib 3.8.0 was released this week and included typing hints. This fixes the resulting CI breakages. * Fix issues. * formatting * Change to seaborn v0_8
1 parent f715527 commit b630298

File tree

13 files changed

+56
-32
lines changed

13 files changed

+56
-32
lines changed

cirq-core/cirq/contrib/svg/svg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
from typing import TYPE_CHECKING, List, Tuple, cast, Dict
33

44
import matplotlib.textpath
5+
import matplotlib.font_manager
6+
57

68
if TYPE_CHECKING:
79
import cirq
810

911
QBLUE = '#1967d2'
10-
FONT = "Arial"
12+
FONT = matplotlib.font_manager.FontProperties(family="Arial")
1113
EMPTY_MOMENT_COLWIDTH = float(21) # assumed default column width
1214

1315

cirq-core/cirq/devices/named_topologies.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _node_and_coordinates(
7474

7575

7676
def draw_gridlike(
77-
graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs
77+
graph: nx.Graph, ax: Optional[plt.Axes] = None, tilted: bool = True, **kwargs
7878
) -> Dict[Any, Tuple[int, int]]:
7979
"""Draw a grid-like graph using Matplotlib.
8080

cirq-core/cirq/experiments/qubit_characterizations.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import dataclasses
1616
import itertools
1717

18-
from typing import Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING
18+
from typing import Any, cast, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING
1919
import numpy as np
2020

2121
from matplotlib import pyplot as plt
2222

2323
# this is for older systems with matplotlib <3.2 otherwise 3d projections fail
24-
from mpl_toolkits import mplot3d # pylint: disable=unused-import
24+
from mpl_toolkits import mplot3d
2525
from cirq import circuits, ops, protocols
2626

2727
if TYPE_CHECKING:
@@ -89,8 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes:
8989
"""
9090
show_plot = not ax
9191
if not ax:
92-
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
93-
ax.set_ylim([0, 1])
92+
fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover
93+
ax = cast(plt.Axes, ax) # pragma: no cover
94+
ax.set_ylim((0.0, 1.0)) # pragma: no cover
9495
ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs)
9596
ax.set_xlabel(r"Number of Cliffords")
9697
ax.set_ylabel('Ground State Probability')
@@ -541,7 +542,7 @@ def _find_inv_matrix(mat: np.ndarray, mat_sequence: np.ndarray) -> int:
541542
def _matrix_bar_plot(
542543
mat: np.ndarray,
543544
z_label: str,
544-
ax: plt.Axes,
545+
ax: mplot3d.axes3d.Axes3D,
545546
kets: Optional[Sequence[str]] = None,
546547
title: Optional[str] = None,
547548
ylim: Tuple[int, int] = (-1, 1),

cirq-core/cirq/linalg/decompositions.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import (
2121
Any,
2222
Callable,
23+
cast,
2324
Iterable,
2425
List,
2526
Optional,
@@ -33,7 +34,7 @@
3334
import matplotlib.pyplot as plt
3435

3536
# this is for older systems with matplotlib <3.2 otherwise 3d projections fail
36-
from mpl_toolkits import mplot3d # pylint: disable=unused-import
37+
from mpl_toolkits import mplot3d
3738
import numpy as np
3839

3940
from cirq import value, protocols
@@ -554,7 +555,7 @@ def scatter_plot_normalized_kak_interaction_coefficients(
554555
interactions: Iterable[Union[np.ndarray, 'cirq.SupportsUnitary', 'KakDecomposition']],
555556
*,
556557
include_frame: bool = True,
557-
ax: Optional[plt.Axes] = None,
558+
ax: Optional[mplot3d.axes3d.Axes3D] = None,
558559
**kwargs,
559560
):
560561
r"""Plots the interaction coefficients of many two-qubit operations.
@@ -633,13 +634,13 @@ def scatter_plot_normalized_kak_interaction_coefficients(
633634
show_plot = not ax
634635
if not ax:
635636
fig = plt.figure()
636-
ax = fig.add_subplot(1, 1, 1, projection='3d')
637+
ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d'))
637638

638639
def coord_transform(
639640
pts: Union[List[Tuple[int, int, int]], np.ndarray]
640-
) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]:
641+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
641642
if len(pts) == 0:
642-
return [], [], []
643+
return np.array([]), np.array([]), np.array([])
643644
xs, ys, zs = np.transpose(pts)
644645
return xs, zs, ys
645646

cirq-core/cirq/vis/heatmap.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import astuple, dataclass
1616
from typing import (
1717
Any,
18+
cast,
1819
Dict,
1920
List,
2021
Mapping,
@@ -217,7 +218,7 @@ def _plot_colorbar(
217218
)
218219
position = self._config['colorbar_position']
219220
orien = 'vertical' if position in ('left', 'right') else 'horizontal'
220-
colorbar = ax.figure.colorbar(
221+
colorbar = cast(plt.Figure, ax.figure).colorbar(
221222
mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {})
222223
)
223224
colorbar_ax.tick_params(axis='y', direction='out')
@@ -230,15 +231,15 @@ def _write_annotations(
230231
ax: plt.Axes,
231232
) -> None:
232233
"""Writes annotations to the center of cells. Internal."""
233-
for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
234+
for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()):
234235
# Calculate the center of the cell, assuming that it is a square
235236
# centered at (x=col, y=row).
236237
if not annotation:
237238
continue
238239
x, y = center
239-
face_luminance = vis_utils.relative_luminance(facecolor)
240+
face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore
240241
text_color = 'black' if face_luminance > 0.4 else 'white'
241-
text_kwargs = dict(color=text_color, ha="center", va="center")
242+
text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center")
242243
text_kwargs.update(self._config.get('annotation_text_kwargs', {}))
243244
ax.text(x, y, annotation, **text_kwargs)
244245

@@ -295,6 +296,7 @@ def plot(
295296
show_plot = not ax
296297
if not ax:
297298
fig, ax = plt.subplots(figsize=(8, 8))
299+
ax = cast(plt.Axes, ax)
298300
original_config = copy.deepcopy(self._config)
299301
self.update_config(**kwargs)
300302
collection = self._plot_on_axis(ax)
@@ -381,6 +383,7 @@ def plot(
381383
show_plot = not ax
382384
if not ax:
383385
fig, ax = plt.subplots(figsize=(8, 8))
386+
ax = cast(plt.Axes, ax)
384387
original_config = copy.deepcopy(self._config)
385388
self.update_config(**kwargs)
386389
qubits = set([q for qubits in self._value_map.keys() for q in qubits])

cirq-core/cirq/vis/heatmap_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def ax():
3434
return figure.add_subplot(111)
3535

3636

37+
def test_default_ax():
38+
row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
39+
test_value_map = {
40+
grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list
41+
}
42+
_, _ = heatmap.Heatmap(test_value_map).plot()
43+
44+
3745
@pytest.mark.parametrize('tuple_keys', [True, False])
3846
def test_cells_positions(ax, tuple_keys):
3947
row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
@@ -61,6 +69,8 @@ def test_two_qubit_heatmap(ax):
6169
title = "Two Qubit Interaction Heatmap"
6270
heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax)
6371
assert ax.get_title() == title
72+
# Test default axis
73+
heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot()
6474

6575

6676
def test_invalid_args():

cirq-core/cirq/vis/histogram.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def integrated_histogram(
100100
plot_options.update(kwargs)
101101

102102
if cdf_on_x:
103-
ax.step(bin_values, parameter_values, **plot_options)
103+
ax.step(bin_values, parameter_values, **plot_options) # type: ignore
104104
else:
105-
ax.step(parameter_values, bin_values, **plot_options)
105+
ax.step(parameter_values, bin_values, **plot_options) # type: ignore
106106

107107
set_semilog = ax.semilogy if cdf_on_x else ax.semilogx
108108
set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim
@@ -128,15 +128,15 @@ def integrated_histogram(
128128

129129
if median_line:
130130
set_line(
131-
np.median(float_data),
131+
float(np.median(float_data)),
132132
linestyle='--',
133133
color=plot_options['color'],
134134
alpha=0.5,
135135
label=median_label,
136136
)
137137
if mean_line:
138138
set_line(
139-
np.mean(float_data),
139+
float(np.mean(float_data)),
140140
linestyle='-.',
141141
color=plot_options['color'],
142142
alpha=0.5,

cirq-core/cirq/vis/state_histogram.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Tool to visualize the results of a study."""
1616

17-
from typing import Union, Optional, Sequence, SupportsFloat
17+
from typing import cast, Optional, Sequence, SupportsFloat, Union
1818
import collections
1919
import numpy as np
2020
import matplotlib.pyplot as plt
@@ -51,13 +51,13 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray:
5151

5252
def plot_state_histogram(
5353
data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]],
54-
ax: Optional['plt.Axis'] = None,
54+
ax: Optional[plt.Axes] = None,
5555
*,
5656
tick_label: Optional[Sequence[str]] = None,
5757
xlabel: Optional[str] = 'qubit state',
5858
ylabel: Optional[str] = 'result count',
5959
title: Optional[str] = 'Result State Histogram',
60-
) -> 'plt.Axis':
60+
) -> plt.Axes:
6161
"""Plot the state histogram from either a single result with repetitions or
6262
a histogram computed using `result.histogram()` or a flattened histogram
6363
of measurement results computed using `get_state_histogram`.
@@ -87,6 +87,7 @@ def plot_state_histogram(
8787
show_fig = not ax
8888
if not ax:
8989
fig, ax = plt.subplots(1, 1)
90+
ax = cast(plt.Axes, ax)
9091
if isinstance(data, result.Result):
9192
values = get_state_histogram(data)
9293
elif isinstance(data, collections.Counter):
@@ -96,9 +97,12 @@ def plot_state_histogram(
9697
if tick_label is None:
9798
tick_label = [str(i) for i in range(len(values))]
9899
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
99-
ax.set_xlabel(xlabel)
100-
ax.set_ylabel(ylabel)
101-
ax.set_title(title)
100+
if xlabel:
101+
ax.set_xlabel(xlabel)
102+
if ylabel:
103+
ax.set_ylabel(ylabel)
104+
if title:
105+
ax.set_title(title)
102106
if show_fig:
103107
fig.show()
104108
return ax

cirq-core/cirq/vis/state_histogram_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def test_plot_state_histogram_result():
7878
for r1, r2 in zip(ax1.get_children(), ax2.get_children()):
7979
if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle):
8080
assert str(r1) == str(r2)
81+
# Test default axis
82+
state_histogram.plot_state_histogram(expected_values)
8183

8284

8385
@pytest.mark.usefixtures('closefigures')

cirq-google/cirq_google/engine/calibration.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import abc, defaultdict
1818
import datetime
1919
from itertools import cycle
20-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Sequence
20+
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union, Sequence
2121

2222
import matplotlib as mpl
2323
import matplotlib.pyplot as plt
@@ -277,6 +277,7 @@ def plot_histograms(
277277
show_plot = not ax
278278
if not ax:
279279
fig, ax = plt.subplots(1, 1)
280+
ax = cast(plt.Axes, ax)
280281

281282
if isinstance(keys, str):
282283
keys = [keys]
@@ -322,7 +323,7 @@ def plot(
322323
show_plot = not fig
323324
if not fig:
324325
fig = plt.figure()
325-
axs = fig.subplots(1, 2)
326+
axs = cast(List[plt.Axes], fig.subplots(1, 2))
326327
self.heatmap(key).plot(axs[0])
327328
self.plot_histograms(key, axs[1])
328329
if show_plot:

docs/experiments/textbook_algorithms.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@
10101010
"outputs": [],
10111011
"source": [
10121012
"\"\"\"Plot the results.\"\"\"\n",
1013-
"plt.style.use(\"seaborn-whitegrid\")\n",
1013+
"plt.style.use(\"seaborn-v0_8-whitegrid\")\n",
10141014
"\n",
10151015
"plt.plot(nvals, estimates, \"--o\", label=\"Phase estimation\")\n",
10161016
"plt.axhline(theta, label=\"True value\", color=\"black\")\n",

docs/start/intro.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,7 @@
14531453
" probs.append(prob[0])\n",
14541454
"\n",
14551455
"# Plot the probability of the ground state at each simulation step.\n",
1456-
"plt.style.use('seaborn-whitegrid')\n",
1456+
"plt.style.use('seaborn-v0_8-whitegrid')\n",
14571457
"plt.plot(probs, 'o')\n",
14581458
"plt.xlabel(\"Step\")\n",
14591459
"plt.ylabel(\"Probability of ground state\");"
@@ -1490,7 +1490,7 @@
14901490
"\n",
14911491
"\n",
14921492
"# Plot the probability of the ground state at each simulation step.\n",
1493-
"plt.style.use('seaborn-whitegrid')\n",
1493+
"plt.style.use('seaborn-v0_8-whitegrid')\n",
14941494
"plt.plot(sampled_probs, 'o')\n",
14951495
"plt.xlabel(\"Step\")\n",
14961496
"plt.ylabel(\"Probability of ground state\");"

examples/two_qubit_gate_compilation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01):
8888
print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}')
8989

9090
plt.figure()
91-
plt.hist(infidelities_arr, bins=25, range=[0, max_infidelity * 1.1])
91+
plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover
9292
ylim = plt.ylim()
9393
plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity')
9494
plt.xlabel('Compiled gate infidelity vs target')

0 commit comments

Comments
 (0)