Skip to content

Commit 0aa0ae4

Browse files
Fix some scatter plot issues (#7167)
* User markersize for scatter plots. * fix .values_unique not returning same values as .values * fix typing issues in _Normalize * fix typing issues with None with hueplt_norms * fix typing issues with None hueplt_norm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 24d038f commit 0aa0ae4

File tree

3 files changed

+135
-87
lines changed

3 files changed

+135
-87
lines changed

xarray/plot/dataarray_plot.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Iterable,
1111
Literal,
1212
MutableMapping,
13+
cast,
1314
overload,
1415
)
1516

@@ -925,7 +926,7 @@ def newplotfunc(
925926

926927
_is_facetgrid = kwargs.pop("_is_facetgrid", False)
927928

928-
if markersize is not None:
929+
if plotfunc.__name__ == "scatter":
929930
size_ = markersize
930931
size_r = _MARKERSIZE_RANGE
931932
else:
@@ -960,7 +961,7 @@ def newplotfunc(
960961

961962
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
962963
plotfunc,
963-
hueplt_norm.values.data,
964+
cast("DataArray", hueplt_norm.values).data,
964965
**locals(),
965966
)
966967

@@ -1013,13 +1014,7 @@ def newplotfunc(
10131014
)
10141015

10151016
if add_legend_:
1016-
if plotfunc.__name__ == "hist":
1017-
ax.legend(
1018-
handles=primitive[-1],
1019-
labels=list(hueplt_norm.values.to_numpy()),
1020-
title=label_from_attrs(hueplt_norm.data),
1021-
)
1022-
elif plotfunc.__name__ in ["scatter", "line"]:
1017+
if plotfunc.__name__ in ["scatter", "line"]:
10231018
_add_legend(
10241019
hueplt_norm
10251020
if add_legend or not add_colorbar_
@@ -1030,11 +1025,26 @@ def newplotfunc(
10301025
plotfunc=plotfunc.__name__,
10311026
)
10321027
else:
1033-
ax.legend(
1034-
handles=primitive,
1035-
labels=list(hueplt_norm.values.to_numpy()),
1036-
title=label_from_attrs(hueplt_norm.data),
1037-
)
1028+
hueplt_norm_values: list[np.ndarray | None]
1029+
if hueplt_norm.data is not None:
1030+
hueplt_norm_values = list(
1031+
cast("DataArray", hueplt_norm.data).to_numpy()
1032+
)
1033+
else:
1034+
hueplt_norm_values = [hueplt_norm.data]
1035+
1036+
if plotfunc.__name__ == "hist":
1037+
ax.legend(
1038+
handles=primitive[-1],
1039+
labels=hueplt_norm_values,
1040+
title=label_from_attrs(hueplt_norm.data),
1041+
)
1042+
else:
1043+
ax.legend(
1044+
handles=primitive,
1045+
labels=hueplt_norm_values,
1046+
title=label_from_attrs(hueplt_norm.data),
1047+
)
10381048

10391049
_update_axes(
10401050
ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim

xarray/plot/facetgrid.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Iterable,
1313
Literal,
1414
TypeVar,
15+
cast,
1516
)
1617

1718
import numpy as np
@@ -41,6 +42,9 @@
4142
from matplotlib.quiver import QuiverKey
4243
from matplotlib.text import Annotation
4344

45+
from ..core.dataarray import DataArray
46+
47+
4448
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
4549
# from mpl.rcParams
4650
_FONTSIZE = "small"
@@ -402,18 +406,24 @@ def map_plot1d(
402406
hueplt_norm = _Normalize(hueplt)
403407
self._hue_var = hueplt
404408
cbar_kwargs = kwargs.pop("cbar_kwargs", {})
405-
if not hueplt_norm.data_is_numeric:
406-
# TODO: Ticks seems a little too hardcoded, since it will always
407-
# show all the values. But maybe it's ok, since plotting hundreds
408-
# of categorical data isn't that meaningful anyway.
409-
cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks)
410-
kwargs.update(levels=hueplt_norm.levels)
411-
if "label" not in cbar_kwargs:
412-
cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data)
413-
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
414-
func, hueplt_norm.values.to_numpy(), cbar_kwargs=cbar_kwargs, **kwargs
415-
)
416-
self._cmap_extend = cmap_params.get("extend")
409+
410+
if hueplt_norm.data is not None:
411+
if not hueplt_norm.data_is_numeric:
412+
# TODO: Ticks seems a little too hardcoded, since it will always
413+
# show all the values. But maybe it's ok, since plotting hundreds
414+
# of categorical data isn't that meaningful anyway.
415+
cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks)
416+
kwargs.update(levels=hueplt_norm.levels)
417+
418+
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
419+
func,
420+
cast("DataArray", hueplt_norm.values).data,
421+
cbar_kwargs=cbar_kwargs,
422+
**kwargs,
423+
)
424+
self._cmap_extend = cmap_params.get("extend")
425+
else:
426+
cmap_params = {}
417427

418428
# Handle sizes:
419429
_size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE
@@ -513,6 +523,9 @@ def map_plot1d(
513523

514524
if add_colorbar:
515525
# Colorbar is after legend so it correctly fits the plot:
526+
if "label" not in cbar_kwargs:
527+
cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data)
528+
516529
self.add_colorbar(**cbar_kwargs)
517530

518531
return self

0 commit comments

Comments
 (0)