Skip to content

Commit ea28861

Browse files
authored
Single matplotlib import (#5794)
1 parent b791558 commit ea28861

File tree

5 files changed

+31
-42
lines changed

5 files changed

+31
-42
lines changed

asv_bench/benchmarks/import_xarray.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class ImportXarray:
2+
def setup(self, *args, **kwargs):
3+
def import_xr():
4+
import xarray # noqa: F401
5+
6+
self._import_xr = import_xr
7+
8+
def time_import_xarray(self):
9+
self._import_xr()

xarray/plot/dataset_plot.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_process_cmap_cbar_kwargs,
1313
get_axis,
1414
label_from_attrs,
15+
plt,
1516
)
1617

1718
# copied from seaborn
@@ -134,8 +135,7 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)
134135

135136
# copied from seaborn
136137
def _parse_size(data, norm):
137-
138-
import matplotlib as mpl
138+
mpl = plt.matplotlib
139139

140140
if data is None:
141141
return None
@@ -544,8 +544,6 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
544544
545545
Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
546546
"""
547-
import matplotlib as mpl
548-
549547
if x is None or y is None or u is None or v is None:
550548
raise ValueError("Must specify x, y, u, v for quiver plots.")
551549

@@ -560,7 +558,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
560558

561559
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
562560
if not cmap_params["norm"]:
563-
cmap_params["norm"] = mpl.colors.Normalize(
561+
cmap_params["norm"] = plt.Normalize(
564562
cmap_params.pop("vmin"), cmap_params.pop("vmax")
565563
)
566564

@@ -576,8 +574,6 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
576574
577575
Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
578576
"""
579-
import matplotlib as mpl
580-
581577
if x is None or y is None or u is None or v is None:
582578
raise ValueError("Must specify x, y, u, v for streamplot plots.")
583579

@@ -613,7 +609,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
613609

614610
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
615611
if not cmap_params["norm"]:
616-
cmap_params["norm"] = mpl.colors.Normalize(
612+
cmap_params["norm"] = plt.Normalize(
617613
cmap_params.pop("vmin"), cmap_params.pop("vmax")
618614
)
619615

xarray/plot/facetgrid.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
_get_nice_quiver_magnitude,
1010
_infer_xy_labels,
1111
_process_cmap_cbar_kwargs,
12-
import_matplotlib_pyplot,
1312
label_from_attrs,
13+
plt,
1414
)
1515

1616
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
@@ -116,8 +116,6 @@ def __init__(
116116
117117
"""
118118

119-
plt = import_matplotlib_pyplot()
120-
121119
# Handle corner case of nonunique coordinates
122120
rep_col = col is not None and not data[col].to_index().is_unique
123121
rep_row = row is not None and not data[row].to_index().is_unique
@@ -519,10 +517,8 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar
519517
self: FacetGrid object
520518
521519
"""
522-
import matplotlib as mpl
523-
524520
if size is None:
525-
size = mpl.rcParams["axes.labelsize"]
521+
size = plt.rcParams["axes.labelsize"]
526522

527523
nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
528524

@@ -619,8 +615,6 @@ def map(self, func, *args, **kwargs):
619615
self : FacetGrid object
620616
621617
"""
622-
plt = import_matplotlib_pyplot()
623-
624618
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
625619
if namedict is not None:
626620
data = self.data.loc[namedict]

xarray/plot/plot.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
_resolve_intervals_2dplot,
3030
_update_axes,
3131
get_axis,
32-
import_matplotlib_pyplot,
3332
label_from_attrs,
3433
legend_elements,
34+
plt,
3535
)
3636

3737
# copied from seaborn
@@ -83,8 +83,6 @@ def _parse_size(data, norm, width):
8383
8484
If the data is categorical, normalize it to numbers.
8585
"""
86-
plt = import_matplotlib_pyplot()
87-
8886
if data is None:
8987
return None
9088

@@ -682,8 +680,6 @@ def scatter(
682680
**kwargs : optional
683681
Additional keyword arguments to matplotlib
684682
"""
685-
plt = import_matplotlib_pyplot()
686-
687683
# Handle facetgrids first
688684
if row or col:
689685
allargs = locals().copy()
@@ -1111,8 +1107,6 @@ def newplotfunc(
11111107
allargs["plotfunc"] = globals()[plotfunc.__name__]
11121108
return _easy_facetgrid(darray, kind="dataarray", **allargs)
11131109

1114-
plt = import_matplotlib_pyplot()
1115-
11161110
if (
11171111
plotfunc.__name__ == "surface"
11181112
and not kwargs.get("_is_facetgrid", False)

xarray/plot/utils.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def import_matplotlib_pyplot():
4747
return plt
4848

4949

50+
try:
51+
plt = import_matplotlib_pyplot()
52+
except ImportError:
53+
plt = None
54+
55+
5056
def _determine_extend(calc_data, vmin, vmax):
5157
extend_min = calc_data.min() < vmin
5258
extend_max = calc_data.max() > vmax
@@ -64,7 +70,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
6470
"""
6571
Build a discrete colormap and normalization of the data.
6672
"""
67-
import matplotlib as mpl
73+
mpl = plt.matplotlib
6874

6975
if len(levels) == 1:
7076
levels = [levels[0], levels[0]]
@@ -115,8 +121,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
115121

116122

117123
def _color_palette(cmap, n_colors):
118-
import matplotlib.pyplot as plt
119-
from matplotlib.colors import ListedColormap
124+
ListedColormap = plt.matplotlib.colors.ListedColormap
120125

121126
colors_i = np.linspace(0, 1.0, n_colors)
122127
if isinstance(cmap, (list, tuple)):
@@ -177,7 +182,7 @@ def _determine_cmap_params(
177182
cmap_params : dict
178183
Use depends on the type of the plotting function
179184
"""
180-
import matplotlib as mpl
185+
mpl = plt.matplotlib
181186

182187
if isinstance(levels, Iterable):
183188
levels = sorted(levels)
@@ -285,13 +290,13 @@ def _determine_cmap_params(
285290
levels = np.asarray([(vmin + vmax) / 2])
286291
else:
287292
# N in MaxNLocator refers to bins, not ticks
288-
ticker = mpl.ticker.MaxNLocator(levels - 1)
293+
ticker = plt.MaxNLocator(levels - 1)
289294
levels = ticker.tick_values(vmin, vmax)
290295
vmin, vmax = levels[0], levels[-1]
291296

292297
# GH3734
293298
if vmin == vmax:
294-
vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)
299+
vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax)
295300

296301
if extend is None:
297302
extend = _determine_extend(calc_data, vmin, vmax)
@@ -421,10 +426,7 @@ def _assert_valid_xy(darray, xy, name):
421426

422427

423428
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
424-
try:
425-
import matplotlib as mpl
426-
import matplotlib.pyplot as plt
427-
except ImportError:
429+
if plt is None:
428430
raise ImportError("matplotlib is required for plot.utils.get_axis")
429431

430432
if figsize is not None:
@@ -437,7 +439,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
437439
if ax is not None:
438440
raise ValueError("cannot provide both `size` and `ax` arguments")
439441
if aspect is None:
440-
width, height = mpl.rcParams["figure.figsize"]
442+
width, height = plt.rcParams["figure.figsize"]
441443
aspect = width / height
442444
figsize = (size * aspect, size)
443445
_, ax = plt.subplots(figsize=figsize)
@@ -454,9 +456,6 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
454456

455457

456458
def _maybe_gca(**kwargs):
457-
458-
import matplotlib.pyplot as plt
459-
460459
# can call gcf unconditionally: either it exists or would be created by plt.axes
461460
f = plt.gcf()
462461

@@ -912,9 +911,7 @@ def _process_cmap_cbar_kwargs(
912911

913912

914913
def _get_nice_quiver_magnitude(u, v):
915-
import matplotlib as mpl
916-
917-
ticker = mpl.ticker.MaxNLocator(3)
914+
ticker = plt.MaxNLocator(3)
918915
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
919916
magnitude = ticker.tick_values(0, mean)[-2]
920917
return magnitude
@@ -989,7 +986,7 @@ def legend_elements(
989986
"""
990987
import warnings
991988

992-
import matplotlib as mpl
989+
mpl = plt.matplotlib
993990

994991
mlines = mpl.lines
995992

@@ -1126,7 +1123,6 @@ def _legend_add_subtitle(handles, labels, text, func):
11261123

11271124
def _adjust_legend_subtitles(legend):
11281125
"""Make invisible-handle "subtitles" entries look more like titles."""
1129-
plt = import_matplotlib_pyplot()
11301126

11311127
# Legend title not in rcParams until 3.0
11321128
font_size = plt.rcParams.get("legend.title_fontsize", None)

0 commit comments

Comments
 (0)