Skip to content

Commit 2694046

Browse files
jaicherIllviljan
andauthored
Revert "Single matplotlib import (#5794)" (#6064)
This reverts commit ea28861. Co-authored-by: Illviljan <[email protected]>
1 parent 379b5b7 commit 2694046

File tree

5 files changed

+42
-31
lines changed

5 files changed

+42
-31
lines changed

asv_bench/benchmarks/import_xarray.py

-9
This file was deleted.

xarray/plot/dataset_plot.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
_process_cmap_cbar_kwargs,
1313
get_axis,
1414
label_from_attrs,
15-
plt,
1615
)
1716

1817
# copied from seaborn
@@ -135,7 +134,8 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)
135134

136135
# copied from seaborn
137136
def _parse_size(data, norm):
138-
mpl = plt.matplotlib
137+
138+
import matplotlib as mpl
139139

140140
if data is None:
141141
return None
@@ -544,6 +544,8 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
544544
545545
Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
546546
"""
547+
import matplotlib as mpl
548+
547549
if x is None or y is None or u is None or v is None:
548550
raise ValueError("Must specify x, y, u, v for quiver plots.")
549551

@@ -558,7 +560,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
558560

559561
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
560562
if not cmap_params["norm"]:
561-
cmap_params["norm"] = plt.Normalize(
563+
cmap_params["norm"] = mpl.colors.Normalize(
562564
cmap_params.pop("vmin"), cmap_params.pop("vmax")
563565
)
564566

@@ -574,6 +576,8 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
574576
575577
Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
576578
"""
579+
import matplotlib as mpl
580+
577581
if x is None or y is None or u is None or v is None:
578582
raise ValueError("Must specify x, y, u, v for streamplot plots.")
579583

@@ -609,7 +613,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
609613

610614
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
611615
if not cmap_params["norm"]:
612-
cmap_params["norm"] = plt.Normalize(
616+
cmap_params["norm"] = mpl.colors.Normalize(
613617
cmap_params.pop("vmin"), cmap_params.pop("vmax")
614618
)
615619

xarray/plot/facetgrid.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
_get_nice_quiver_magnitude,
1010
_infer_xy_labels,
1111
_process_cmap_cbar_kwargs,
12+
import_matplotlib_pyplot,
1213
label_from_attrs,
13-
plt,
1414
)
1515

1616
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
@@ -116,6 +116,8 @@ def __init__(
116116
117117
"""
118118

119+
plt = import_matplotlib_pyplot()
120+
119121
# Handle corner case of nonunique coordinates
120122
rep_col = col is not None and not data[col].to_index().is_unique
121123
rep_row = row is not None and not data[row].to_index().is_unique
@@ -517,8 +519,10 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar
517519
self: FacetGrid object
518520
519521
"""
522+
import matplotlib as mpl
523+
520524
if size is None:
521-
size = plt.rcParams["axes.labelsize"]
525+
size = mpl.rcParams["axes.labelsize"]
522526

523527
nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
524528

@@ -615,6 +619,8 @@ def map(self, func, *args, **kwargs):
615619
self : FacetGrid object
616620
617621
"""
622+
plt = import_matplotlib_pyplot()
623+
618624
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
619625
if namedict is not None:
620626
data = self.data.loc[namedict]

xarray/plot/plot.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
_resolve_intervals_2dplot,
3030
_update_axes,
3131
get_axis,
32+
import_matplotlib_pyplot,
3233
label_from_attrs,
3334
legend_elements,
34-
plt,
3535
)
3636

3737
# copied from seaborn
@@ -83,6 +83,8 @@ def _parse_size(data, norm, width):
8383
8484
If the data is categorical, normalize it to numbers.
8585
"""
86+
plt = import_matplotlib_pyplot()
87+
8688
if data is None:
8789
return None
8890

@@ -680,6 +682,8 @@ def scatter(
680682
**kwargs : optional
681683
Additional keyword arguments to matplotlib
682684
"""
685+
plt = import_matplotlib_pyplot()
686+
683687
# Handle facetgrids first
684688
if row or col:
685689
allargs = locals().copy()
@@ -1107,6 +1111,8 @@ def newplotfunc(
11071111
allargs["plotfunc"] = globals()[plotfunc.__name__]
11081112
return _easy_facetgrid(darray, kind="dataarray", **allargs)
11091113

1114+
plt = import_matplotlib_pyplot()
1115+
11101116
if (
11111117
plotfunc.__name__ == "surface"
11121118
and not kwargs.get("_is_facetgrid", False)

xarray/plot/utils.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ def import_matplotlib_pyplot():
4747
return plt
4848

4949

50-
try:
51-
plt = import_matplotlib_pyplot()
52-
except ImportError:
53-
plt = None
54-
55-
5650
def _determine_extend(calc_data, vmin, vmax):
5751
extend_min = calc_data.min() < vmin
5852
extend_max = calc_data.max() > vmax
@@ -70,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
7064
"""
7165
Build a discrete colormap and normalization of the data.
7266
"""
73-
mpl = plt.matplotlib
67+
import matplotlib as mpl
7468

7569
if len(levels) == 1:
7670
levels = [levels[0], levels[0]]
@@ -121,7 +115,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
121115

122116

123117
def _color_palette(cmap, n_colors):
124-
ListedColormap = plt.matplotlib.colors.ListedColormap
118+
import matplotlib.pyplot as plt
119+
from matplotlib.colors import ListedColormap
125120

126121
colors_i = np.linspace(0, 1.0, n_colors)
127122
if isinstance(cmap, (list, tuple)):
@@ -182,7 +177,7 @@ def _determine_cmap_params(
182177
cmap_params : dict
183178
Use depends on the type of the plotting function
184179
"""
185-
mpl = plt.matplotlib
180+
import matplotlib as mpl
186181

187182
if isinstance(levels, Iterable):
188183
levels = sorted(levels)
@@ -290,13 +285,13 @@ def _determine_cmap_params(
290285
levels = np.asarray([(vmin + vmax) / 2])
291286
else:
292287
# N in MaxNLocator refers to bins, not ticks
293-
ticker = plt.MaxNLocator(levels - 1)
288+
ticker = mpl.ticker.MaxNLocator(levels - 1)
294289
levels = ticker.tick_values(vmin, vmax)
295290
vmin, vmax = levels[0], levels[-1]
296291

297292
# GH3734
298293
if vmin == vmax:
299-
vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax)
294+
vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)
300295

301296
if extend is None:
302297
extend = _determine_extend(calc_data, vmin, vmax)
@@ -426,7 +421,10 @@ def _assert_valid_xy(darray, xy, name):
426421

427422

428423
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
429-
if plt is None:
424+
try:
425+
import matplotlib as mpl
426+
import matplotlib.pyplot as plt
427+
except ImportError:
430428
raise ImportError("matplotlib is required for plot.utils.get_axis")
431429

432430
if figsize is not None:
@@ -439,7 +437,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
439437
if ax is not None:
440438
raise ValueError("cannot provide both `size` and `ax` arguments")
441439
if aspect is None:
442-
width, height = plt.rcParams["figure.figsize"]
440+
width, height = mpl.rcParams["figure.figsize"]
443441
aspect = width / height
444442
figsize = (size * aspect, size)
445443
_, ax = plt.subplots(figsize=figsize)
@@ -456,6 +454,9 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
456454

457455

458456
def _maybe_gca(**kwargs):
457+
458+
import matplotlib.pyplot as plt
459+
459460
# can call gcf unconditionally: either it exists or would be created by plt.axes
460461
f = plt.gcf()
461462

@@ -913,7 +914,9 @@ def _process_cmap_cbar_kwargs(
913914

914915

915916
def _get_nice_quiver_magnitude(u, v):
916-
ticker = plt.MaxNLocator(3)
917+
import matplotlib as mpl
918+
919+
ticker = mpl.ticker.MaxNLocator(3)
917920
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
918921
magnitude = ticker.tick_values(0, mean)[-2]
919922
return magnitude
@@ -988,7 +991,7 @@ def legend_elements(
988991
"""
989992
import warnings
990993

991-
mpl = plt.matplotlib
994+
import matplotlib as mpl
992995

993996
mlines = mpl.lines
994997

@@ -1125,6 +1128,7 @@ def _legend_add_subtitle(handles, labels, text, func):
11251128

11261129
def _adjust_legend_subtitles(legend):
11271130
"""Make invisible-handle "subtitles" entries look more like titles."""
1131+
plt = import_matplotlib_pyplot()
11281132

11291133
# Legend title not in rcParams until 3.0
11301134
font_size = plt.rcParams.get("legend.title_fontsize", None)

0 commit comments

Comments
 (0)