19
19
_assert_valid_xy ,
20
20
_determine_guide ,
21
21
_ensure_plottable ,
22
+ _guess_coords_to_plot ,
22
23
_infer_interval_breaks ,
23
24
_infer_xy_labels ,
24
25
_Normalize ,
@@ -142,48 +143,45 @@ def _infer_line_data(
142
143
return xplt , yplt , hueplt , huelabel
143
144
144
145
145
- def _infer_plot_dims (
146
- darray : DataArray ,
147
- dims_plot : MutableMapping [str , Hashable ],
148
- default_guess : Iterable [str ] = ("x" , "hue" , "size" ),
149
- ) -> MutableMapping [str , Hashable ]:
146
+ def _prepare_plot1d_data (
147
+ darray : T_DataArray ,
148
+ coords_to_plot : MutableMapping [str , Hashable ],
149
+ plotfunc_name : str | None = None ,
150
+ _is_facetgrid : bool = False ,
151
+ ) -> dict [str , T_DataArray ]:
150
152
"""
151
- Guess what dims to plot if some of the values in dims_plot are None which
152
- happens when the user has not defined all available ways of visualizing
153
- the data.
153
+ Prepare data for usage with plt.scatter.
154
154
155
155
Parameters
156
156
----------
157
- darray : DataArray
158
- The DataArray to check.
159
- dims_plot : T_DimsPlot
160
- Dims defined by the user to plot.
161
- default_guess : Iterable[str], optional
162
- Default values and order to retrieve dims if values in dims_plot is
163
- missing, default: ("x", "hue", "size").
164
- """
165
- dims_plot_exist = {k : v for k , v in dims_plot .items () if v is not None }
166
- dims_avail = tuple (v for v in darray .dims if v not in dims_plot_exist .values ())
167
-
168
- # If dims_plot[k] isn't defined then fill with one of the available dims:
169
- for k , v in zip (default_guess , dims_avail ):
170
- if dims_plot .get (k , None ) is None :
171
- dims_plot [k ] = v
172
-
173
- for k , v in dims_plot .items ():
174
- _assert_valid_xy (darray , v , k )
175
-
176
- return dims_plot
177
-
157
+ darray : T_DataArray
158
+ Base DataArray.
159
+ coords_to_plot : MutableMapping[str, Hashable]
160
+ Coords that will be plotted.
161
+ plotfunc_name : str | None
162
+ Name of the plotting function that will be used.
178
163
179
- def _infer_line_data2 (
180
- darray : T_DataArray ,
181
- dims_plot : MutableMapping [str , Hashable ],
182
- plotfunc_name : None | str = None ,
183
- ) -> dict [str , T_DataArray ]:
184
- # Guess what dims to use if some of the values in plot_dims are None:
185
- dims_plot = _infer_plot_dims (darray , dims_plot )
164
+ Returns
165
+ -------
166
+ plts : dict[str, T_DataArray]
167
+ Dict of DataArrays that will be sent to matplotlib.
186
168
169
+ Examples
170
+ --------
171
+ >>> # Make sure int coords are plotted:
172
+ >>> a = xr.DataArray(
173
+ ... data=[1, 2],
174
+ ... coords={1: ("x", [0, 1], {"units": "s"})},
175
+ ... dims=("x",),
176
+ ... name="a",
177
+ ... )
178
+ >>> plts = xr.plot.dataarray_plot._prepare_plot1d_data(
179
+ ... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None}
180
+ ... )
181
+ >>> # Check which coords to plot:
182
+ >>> print({k: v.name for k, v in plts.items()})
183
+ {'y': 'a', 'x': 1}
184
+ """
187
185
# If there are more than 1 dimension in the array than stack all the
188
186
# dimensions so the plotter can plot anything:
189
187
if darray .ndim > 1 :
@@ -193,11 +191,11 @@ def _infer_line_data2(
193
191
dims_T = []
194
192
if np .issubdtype (darray .dtype , np .floating ):
195
193
for v in ["z" , "x" ]:
196
- dim = dims_plot .get (v , None )
194
+ dim = coords_to_plot .get (v , None )
197
195
if (dim is not None ) and (dim in darray .dims ):
198
196
darray_nan = np .nan * darray .isel ({dim : - 1 })
199
197
darray = concat ([darray , darray_nan ], dim = dim )
200
- dims_T .append (dims_plot [v ])
198
+ dims_T .append (coords_to_plot [v ])
201
199
202
200
# Lines should never connect to the same coordinate when stacked,
203
201
# transpose to avoid this as much as possible:
@@ -207,11 +205,13 @@ def _infer_line_data2(
207
205
darray = darray .stack (_stacked_dim = darray .dims )
208
206
209
207
# Broadcast together all the chosen variables:
210
- out = dict (y = darray )
211
- out .update ({k : darray [v ] for k , v in dims_plot .items () if v is not None })
212
- out = dict (zip (out .keys (), broadcast (* (out .values ()))))
208
+ plts = dict (y = darray )
209
+ plts .update (
210
+ {k : darray .coords [v ] for k , v in coords_to_plot .items () if v is not None }
211
+ )
212
+ plts = dict (zip (plts .keys (), broadcast (* (plts .values ()))))
213
213
214
- return out
214
+ return plts
215
215
216
216
217
217
# return type is Any due to the many different possibilities
@@ -938,15 +938,20 @@ def newplotfunc(
938
938
_is_facetgrid = kwargs .pop ("_is_facetgrid" , False )
939
939
940
940
if plotfunc .__name__ == "scatter" :
941
- size_ = markersize
941
+ size_ = kwargs . pop ( "_size" , markersize )
942
942
size_r = _MARKERSIZE_RANGE
943
943
else :
944
- size_ = linewidth
944
+ size_ = kwargs . pop ( "_size" , linewidth )
945
945
size_r = _LINEWIDTH_RANGE
946
946
947
947
# Get data to plot:
948
- dims_plot = dict (x = x , z = z , hue = hue , size = size_ )
949
- plts = _infer_line_data2 (darray , dims_plot , plotfunc .__name__ )
948
+ coords_to_plot : MutableMapping [str , Hashable | None ] = dict (
949
+ x = x , z = z , hue = hue , size = size_
950
+ )
951
+ if not _is_facetgrid :
952
+ # Guess what coords to use if some of the values in coords_to_plot are None:
953
+ coords_to_plot = _guess_coords_to_plot (darray , coords_to_plot , kwargs )
954
+ plts = _prepare_plot1d_data (darray , coords_to_plot , plotfunc .__name__ )
950
955
xplt = plts .pop ("x" , None )
951
956
yplt = plts .pop ("y" , None )
952
957
zplt = plts .pop ("z" , None )
0 commit comments