Skip to content

Commit 8a2fbb8

Browse files
cjauvinmathauseIllviljanhuardpre-commit-ci[bot]
authored
Weighted quantile (#6059)
* Add weighted quantile * Add weighted quantile to documentation * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * Improve _weighted_quantile_type7_1d ufunc with suggestions * Expand scope of invalid q value test * Fix weighted quantile with zero weights * Replace np.ones by xr.ones_like in weighted quantile test * Process weighted quantile data with all nans * Fix operator precedence bug * Used effective sample size. Generalize to different quantile types supporting weighted quantiles (4-9, but only 7 is exposed and tested). Fixed unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: Illviljan <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added missing Typing hints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update what's new and pep8 fixes * add docstring paragraph discussing weight interpretation * recognize numpy names for quantile interpolation methods * tweak to avoid warning with all nans data. simplify test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove integers from quantile interpolation available methods * remove merge artifacts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [skip-ci] fix bad merge in whats-new * Add references * renamed htype argument to method in private functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/weighted.py Co-authored-by: Abel Aoun <[email protected]> * Add skipped test to verify equal weights quantile with methods * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * Update xarray/core/weighted.py Co-authored-by: Mathias Hauser <[email protected]> * modifications suggested by review: comments, remove align, clarify test logic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use broadcast * move whatsnew entry * Apply suggestions from code review * switch skipna determination * use align and broadcast Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: Illviljan <[email protected]> Co-authored-by: David Huard <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Abel Aoun <[email protected]> Co-authored-by: Mathias Hauser <[email protected]>
1 parent 8f42bfd commit 8a2fbb8

File tree

5 files changed

+453
-20
lines changed

5 files changed

+453
-20
lines changed

doc/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,7 @@ Dataset
944944

945945
DatasetWeighted
946946
DatasetWeighted.mean
947+
DatasetWeighted.quantile
947948
DatasetWeighted.sum
948949
DatasetWeighted.std
949950
DatasetWeighted.var
@@ -958,6 +959,7 @@ DataArray
958959

959960
DataArrayWeighted
960961
DataArrayWeighted.mean
962+
DataArrayWeighted.quantile
961963
DataArrayWeighted.sum
962964
DataArrayWeighted.std
963965
DataArrayWeighted.var

doc/user-guide/computation.rst

+7-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ Weighted array reductions
265265

266266
:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
267267
and :py:meth:`Dataset.weighted` array reduction methods. They currently
268-
support weighted ``sum``, ``mean``, ``std`` and ``var``.
268+
support weighted ``sum``, ``mean``, ``std``, ``var`` and ``quantile``.
269269

270270
.. ipython:: python
271271
@@ -293,6 +293,12 @@ Calculate the weighted mean:
293293
294294
weighted_prec.mean(dim="month")
295295
296+
Calculate the weighted quantile:
297+
298+
.. ipython:: python
299+
300+
weighted_prec.quantile(q=0.5, dim="month")
301+
296302
The weighted sum corresponds to:
297303

298304
.. ipython:: python

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ v2022.03.1 (unreleased)
2222
New Features
2323
~~~~~~~~~~~~
2424

25+
- Add a weighted ``quantile`` method to :py:class:`~core.weighted.DatasetWeighted` and
26+
:py:class:`~core.weighted.DataArrayWeighted` (:pull:`6059`). By
27+
`Christian Jauvin <https://github.com/cjauvin>`_ and `David Huard <https://github.com/huard>`_.
2528
- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and
2629
:py:meth:`DataArray.stack` so that the creation of multi-indexes is optional
2730
(:pull:`5692`). By `Benoît Bovy <https://github.com/benbovy>`_.

xarray/core/weighted.py

+220-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, cast
3+
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Literal, Sequence, cast
44

55
import numpy as np
66

7-
from . import duck_array_ops
8-
from .computation import dot
7+
from . import duck_array_ops, utils
8+
from .alignment import align, broadcast
9+
from .computation import apply_ufunc, dot
10+
from .npcompat import ArrayLike
911
from .pycompat import is_duck_dask_array
1012
from .types import T_Xarray
1113

14+
# Weighted quantile methods are a subset of the numpy supported quantile methods.
15+
QUANTILE_METHODS = Literal[
16+
"linear",
17+
"interpolated_inverted_cdf",
18+
"hazen",
19+
"weibull",
20+
"median_unbiased",
21+
"normal_unbiased",
22+
]
23+
1224
_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
1325
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
1426
@@ -56,6 +68,61 @@
5668
New {cls} object with the sum of the weights over the given dimension.
5769
"""
5870

71+
_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """
72+
Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s).
73+
74+
Weights are interpreted as *sampling weights* (or probability weights) and
75+
describe how a sample is scaled to the whole population [1]_. There are
76+
other possible interpretations for weights, *precision weights* describing the
77+
precision of observations, or *frequency weights* counting the number of identical
78+
observations, however, they are not implemented here.
79+
80+
For compatibility with NumPy's non-weighted ``quantile`` (which is used by
81+
``DataArray.quantile`` and ``Dataset.quantile``), the only interpolation
82+
method supported by this weighted version corresponds to the default "linear"
83+
option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman
84+
and Fan (1996) [2]_. The implementation is largely inspired by a blog post
85+
from A. Akinshin's [3]_.
86+
87+
Parameters
88+
----------
89+
q : float or sequence of float
90+
Quantile to compute, which must be between 0 and 1 inclusive.
91+
dim : str or sequence of str, optional
92+
Dimension(s) over which to apply the weighted ``quantile``.
93+
skipna : bool, optional
94+
If True, skip missing values (as marked by NaN). By default, only
95+
skips missing values for float dtypes; other dtypes either do not
96+
have a sentinel missing value (int) or skipna=True has not been
97+
implemented (object, datetime64 or timedelta64).
98+
keep_attrs : bool, optional
99+
If True, the attributes (``attrs``) will be copied from the original
100+
object to the new one. If False (default), the new object will be
101+
returned without attributes.
102+
103+
Returns
104+
-------
105+
quantiles : {cls}
106+
New {cls} object with weighted ``quantile`` applied to its data and
107+
the indicated dimension(s) removed.
108+
109+
See Also
110+
--------
111+
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile
112+
113+
Notes
114+
-----
115+
Returns NaN if the ``weights`` sum to 0.0 along the reduced
116+
dimension(s).
117+
118+
References
119+
----------
120+
.. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/
121+
.. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages.
122+
The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934
123+
.. [3] https://aakinshin.net/posts/weighted-quantiles
124+
"""
125+
59126

60127
if TYPE_CHECKING:
61128
from .dataarray import DataArray
@@ -241,6 +308,141 @@ def _weighted_std(
241308

242309
return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
243310

311+
def _weighted_quantile(
312+
self,
313+
da: DataArray,
314+
q: ArrayLike,
315+
dim: Hashable | Iterable[Hashable] | None = None,
316+
skipna: bool = None,
317+
) -> DataArray:
318+
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""
319+
320+
def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
321+
"""Return the interpolation parameter."""
322+
# Note that options are not yet exposed in the public API.
323+
if method == "linear":
324+
h = (n - 1) * q + 1
325+
elif method == "interpolated_inverted_cdf":
326+
h = n * q
327+
elif method == "hazen":
328+
h = n * q + 0.5
329+
elif method == "weibull":
330+
h = (n + 1) * q
331+
elif method == "median_unbiased":
332+
h = (n + 1 / 3) * q + 1 / 3
333+
elif method == "normal_unbiased":
334+
h = (n + 1 / 4) * q + 3 / 8
335+
else:
336+
raise ValueError(f"Invalid method: {method}.")
337+
return h.clip(1, n)
338+
339+
def _weighted_quantile_1d(
340+
data: np.ndarray,
341+
weights: np.ndarray,
342+
q: np.ndarray,
343+
skipna: bool,
344+
method: QUANTILE_METHODS = "linear",
345+
) -> np.ndarray:
346+
347+
# This algorithm has been adapted from:
348+
# https://aakinshin.net/posts/weighted-quantiles/#reference-implementation
349+
is_nan = np.isnan(data)
350+
if skipna:
351+
# Remove nans from data and weights
352+
not_nan = ~is_nan
353+
data = data[not_nan]
354+
weights = weights[not_nan]
355+
elif is_nan.any():
356+
# Return nan if data contains any nan
357+
return np.full(q.size, np.nan)
358+
359+
# Filter out data (and weights) associated with zero weights, which also flattens them
360+
nonzero_weights = weights != 0
361+
data = data[nonzero_weights]
362+
weights = weights[nonzero_weights]
363+
n = data.size
364+
365+
if n == 0:
366+
# Possibly empty after nan or zero weight filtering above
367+
return np.full(q.size, np.nan)
368+
369+
# Kish's effective sample size
370+
nw = weights.sum() ** 2 / (weights**2).sum()
371+
372+
# Sort data and weights
373+
sorter = np.argsort(data)
374+
data = data[sorter]
375+
weights = weights[sorter]
376+
377+
# Normalize and sum the weights
378+
weights = weights / weights.sum()
379+
weights_cum = np.append(0, weights.cumsum())
380+
381+
# Vectorize the computation by transposing q with respect to weights
382+
q = np.atleast_2d(q).T
383+
384+
# Get the interpolation parameter for each q
385+
h = _get_h(nw, q, method)
386+
387+
# Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw)
388+
u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum))
389+
390+
# Compute their relative weight
391+
v = u * nw - h + 1
392+
w = np.diff(v)
393+
394+
# Apply the weights
395+
return (data * w).sum(axis=1)
396+
397+
if skipna is None and da.dtype.kind in "cfO":
398+
skipna = True
399+
400+
q = np.atleast_1d(np.asarray(q, dtype=np.float64))
401+
402+
if q.ndim > 1:
403+
raise ValueError("q must be a scalar or 1d")
404+
405+
if np.any((q < 0) | (q > 1)):
406+
raise ValueError("q values must be between 0 and 1")
407+
408+
if dim is None:
409+
dim = da.dims
410+
411+
if utils.is_scalar(dim):
412+
dim = [dim]
413+
414+
# To satisfy mypy
415+
dim = cast(Sequence, dim)
416+
417+
# need to align *and* broadcast
418+
# - `_weighted_quantile_1d` requires arrays with the same shape
419+
# - broadcast does an outer join, which can introduce NaN to weights
420+
# - therefore we first need to do align(..., join="inner")
421+
422+
# TODO: use broadcast(..., join="inner") once available
423+
# see https://github.com/pydata/xarray/issues/6304
424+
425+
da, weights = align(da, self.weights, join="inner")
426+
da, weights = broadcast(da, weights)
427+
428+
result = apply_ufunc(
429+
_weighted_quantile_1d,
430+
da,
431+
weights,
432+
input_core_dims=[dim, dim],
433+
output_core_dims=[["quantile"]],
434+
output_dtypes=[np.float64],
435+
dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
436+
dask="parallelized",
437+
vectorize=True,
438+
kwargs={"q": q, "skipna": skipna},
439+
)
440+
441+
result = result.transpose("quantile", ...)
442+
result = result.assign_coords(quantile=q).squeeze()
443+
444+
return result
445+
244446
def _implementation(self, func, dim, **kwargs):
245447

246448
raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
@@ -310,6 +512,19 @@ def std(
310512
self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
311513
)
312514

515+
def quantile(
516+
self,
517+
q: ArrayLike,
518+
*,
519+
dim: Hashable | Sequence[Hashable] | None = None,
520+
keep_attrs: bool = None,
521+
skipna: bool = True,
522+
) -> T_Xarray:
523+
524+
return self._implementation(
525+
self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs
526+
)
527+
313528
def __repr__(self):
314529
"""provide a nice str repr of our Weighted object"""
315530

@@ -360,6 +575,8 @@ def _inject_docstring(cls, cls_name):
360575
cls=cls_name, fcn="std", on_zero="NaN"
361576
)
362577

578+
cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name)
579+
363580

364581
_inject_docstring(DataArrayWeighted, "DataArray")
365582
_inject_docstring(DatasetWeighted, "Dataset")

0 commit comments

Comments
 (0)