15
15
from dataclasses import astuple , dataclass
16
16
from typing import (
17
17
Any ,
18
+ cast ,
18
19
Dict ,
19
20
List ,
20
21
Mapping ,
@@ -217,7 +218,7 @@ def _plot_colorbar(
217
218
)
218
219
position = self ._config ['colorbar_position' ]
219
220
orien = 'vertical' if position in ('left' , 'right' ) else 'horizontal'
220
- colorbar = ax .figure .colorbar (
221
+ colorbar = cast ( plt . Figure , ax .figure ) .colorbar (
221
222
mappable , colorbar_ax , ax , orientation = orien , ** self ._config .get ("colorbar_options" , {})
222
223
)
223
224
colorbar_ax .tick_params (axis = 'y' , direction = 'out' )
@@ -230,15 +231,15 @@ def _write_annotations(
230
231
ax : plt .Axes ,
231
232
) -> None :
232
233
"""Writes annotations to the center of cells. Internal."""
233
- for (center , annotation ), facecolor in zip (centers_and_annot , collection .get_facecolors ()):
234
+ for (center , annotation ), facecolor in zip (centers_and_annot , collection .get_facecolor ()):
234
235
# Calculate the center of the cell, assuming that it is a square
235
236
# centered at (x=col, y=row).
236
237
if not annotation :
237
238
continue
238
239
x , y = center
239
- face_luminance = vis_utils .relative_luminance (facecolor )
240
+ face_luminance = vis_utils .relative_luminance (facecolor ) # type: ignore
240
241
text_color = 'black' if face_luminance > 0.4 else 'white'
241
- text_kwargs = dict (color = text_color , ha = "center" , va = "center" )
242
+ text_kwargs : Dict [ str , Any ] = dict (color = text_color , ha = "center" , va = "center" )
242
243
text_kwargs .update (self ._config .get ('annotation_text_kwargs' , {}))
243
244
ax .text (x , y , annotation , ** text_kwargs )
244
245
@@ -295,6 +296,7 @@ def plot(
295
296
show_plot = not ax
296
297
if not ax :
297
298
fig , ax = plt .subplots (figsize = (8 , 8 ))
299
+ ax = cast (plt .Axes , ax )
298
300
original_config = copy .deepcopy (self ._config )
299
301
self .update_config (** kwargs )
300
302
collection = self ._plot_on_axis (ax )
@@ -381,6 +383,7 @@ def plot(
381
383
show_plot = not ax
382
384
if not ax :
383
385
fig , ax = plt .subplots (figsize = (8 , 8 ))
386
+ ax = cast (plt .Axes , ax )
384
387
original_config = copy .deepcopy (self ._config )
385
388
self .update_config (** kwargs )
386
389
qubits = set ([q for qubits in self ._value_map .keys () for q in qubits ])
0 commit comments