Skip to content

Commit 562f2f8

Browse files
llurituLlorenc Lledopre-commit-ci[bot]max-sixty
authored
Added option to specify weights in xr.corr() and xr.cov() (#8527)
* Added function _weighted_cov_corr and modified cov and corr to call it if parameter weights is not None * Correct two indentation errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Stupid typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove the min_count argument from mean * Unified the code for weighted and unweighted _cov_corr * Remove old _cov_corr function after checking that new version produces same results when weights=None or weights=xr.DataArray(1) * Added examples that use weights for cov and corr * Added two tests for weighted correlation and covariance * Fix error in mypy, allow None as weights type. * Update xarray/core/computation.py Co-authored-by: Maximilian Roos <[email protected]> * Update xarray/core/computation.py Co-authored-by: Maximilian Roos <[email protected]> * Info on new options for cov and corr in whatsnew * Info on new options for cov and corr in whatsnew * Fix typing --------- Co-authored-by: Llorenc Lledo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <[email protected]> Co-authored-by: Maximilian Roos <[email protected]>
1 parent 8ad0b83 commit 562f2f8

File tree

3 files changed

+181
-18
lines changed

3 files changed

+181
-18
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ v2023.12.1 (unreleased)
2424
New Features
2525
~~~~~~~~~~~~
2626

27+
- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`).
28+
By `Llorenç Lledó <https://github.com/lluritu>`_.
2729

2830
Breaking changes
2931
~~~~~~~~~~~~~~~~

xarray/core/computation.py

+88-18
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from collections import Counter
1111
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set
12-
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload
12+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload
1313

1414
import numpy as np
1515

@@ -1281,7 +1281,11 @@ def apply_ufunc(
12811281

12821282

12831283
def cov(
1284-
da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
1284+
da_a: T_DataArray,
1285+
da_b: T_DataArray,
1286+
dim: Dims = None,
1287+
ddof: int = 1,
1288+
weights: T_DataArray | None = None,
12851289
) -> T_DataArray:
12861290
"""
12871291
Compute covariance between two DataArray objects along a shared dimension.
@@ -1297,6 +1301,8 @@ def cov(
12971301
ddof : int, default: 1
12981302
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
12991303
else normalization is by N.
1304+
weights : DataArray, optional
1305+
Array of weights.
13001306
13011307
Returns
13021308
-------
@@ -1350,6 +1356,23 @@ def cov(
13501356
array([ 0.2 , -0.5 , 1.69333333])
13511357
Coordinates:
13521358
* space (space) <U2 'IA' 'IL' 'IN'
1359+
>>> weights = DataArray(
1360+
... [4, 2, 1],
1361+
... dims=("space"),
1362+
... coords=[
1363+
... ("space", ["IA", "IL", "IN"]),
1364+
... ],
1365+
... )
1366+
>>> weights
1367+
<xarray.DataArray (space: 3)>
1368+
array([4, 2, 1])
1369+
Coordinates:
1370+
* space (space) <U2 'IA' 'IL' 'IN'
1371+
>>> xr.cov(da_a, da_b, dim="space", weights=weights)
1372+
<xarray.DataArray (time: 3)>
1373+
array([-4.69346939, -4.49632653, -3.37959184])
1374+
Coordinates:
1375+
* time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
13531376
"""
13541377
from xarray.core.dataarray import DataArray
13551378

@@ -1358,11 +1381,18 @@ def cov(
13581381
"Only xr.DataArray is supported."
13591382
f"Given {[type(arr) for arr in [da_a, da_b]]}."
13601383
)
1384+
if weights is not None:
1385+
if not isinstance(weights, DataArray):
1386+
raise TypeError("Only xr.DataArray is supported." f"Given {type(weights)}.")
1387+
return _cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov")
13611388

1362-
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
13631389

1364-
1365-
def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
1390+
def corr(
1391+
da_a: T_DataArray,
1392+
da_b: T_DataArray,
1393+
dim: Dims = None,
1394+
weights: T_DataArray | None = None,
1395+
) -> T_DataArray:
13661396
"""
13671397
Compute the Pearson correlation coefficient between
13681398
two DataArray objects along a shared dimension.
@@ -1375,6 +1405,8 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
13751405
Array to compute.
13761406
dim : str, iterable of hashable, "..." or None, optional
13771407
The dimension along which the correlation will be computed
1408+
weights : DataArray, optional
1409+
Array of weights.
13781410
13791411
Returns
13801412
-------
@@ -1428,6 +1460,23 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
14281460
array([ 1., -1., 1.])
14291461
Coordinates:
14301462
* space (space) <U2 'IA' 'IL' 'IN'
1463+
>>> weights = DataArray(
1464+
... [4, 2, 1],
1465+
... dims=("space"),
1466+
... coords=[
1467+
... ("space", ["IA", "IL", "IN"]),
1468+
... ],
1469+
... )
1470+
>>> weights
1471+
<xarray.DataArray (space: 3)>
1472+
array([4, 2, 1])
1473+
Coordinates:
1474+
* space (space) <U2 'IA' 'IL' 'IN'
1475+
>>> xr.corr(da_a, da_b, dim="space", weights=weights)
1476+
<xarray.DataArray (time: 3)>
1477+
array([-0.50240504, -0.83215028, -0.99057446])
1478+
Coordinates:
1479+
* time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
14311480
"""
14321481
from xarray.core.dataarray import DataArray
14331482

@@ -1436,13 +1485,16 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
14361485
"Only xr.DataArray is supported."
14371486
f"Given {[type(arr) for arr in [da_a, da_b]]}."
14381487
)
1439-
1440-
return _cov_corr(da_a, da_b, dim=dim, method="corr")
1488+
if weights is not None:
1489+
if not isinstance(weights, DataArray):
1490+
raise TypeError("Only xr.DataArray is supported." f"Given {type(weights)}.")
1491+
return _cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr")
14411492

14421493

14431494
def _cov_corr(
14441495
da_a: T_DataArray,
14451496
da_b: T_DataArray,
1497+
weights: T_DataArray | None = None,
14461498
dim: Dims = None,
14471499
ddof: int = 0,
14481500
method: Literal["cov", "corr", None] = None,
@@ -1458,28 +1510,46 @@ def _cov_corr(
14581510
valid_values = da_a.notnull() & da_b.notnull()
14591511
da_a = da_a.where(valid_values)
14601512
da_b = da_b.where(valid_values)
1461-
valid_count = valid_values.sum(dim) - ddof
14621513

14631514
# 3. Detrend along the given dim
1464-
demeaned_da_a = da_a - da_a.mean(dim=dim)
1465-
demeaned_da_b = da_b - da_b.mean(dim=dim)
1515+
if weights is not None:
1516+
demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim)
1517+
demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim)
1518+
else:
1519+
demeaned_da_a = da_a - da_a.mean(dim=dim)
1520+
demeaned_da_b = da_b - da_b.mean(dim=dim)
14661521

14671522
# 4. Compute covariance along the given dim
14681523
# N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
14691524
# Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
1470-
cov = (demeaned_da_a.conj() * demeaned_da_b).sum(
1471-
dim=dim, skipna=True, min_count=1
1472-
) / (valid_count)
1525+
if weights is not None:
1526+
cov = (
1527+
(demeaned_da_a.conj() * demeaned_da_b)
1528+
.weighted(weights)
1529+
.mean(dim=dim, skipna=True)
1530+
)
1531+
else:
1532+
cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True)
14731533

14741534
if method == "cov":
1475-
return cov
1535+
# Adjust covariance for degrees of freedom
1536+
valid_count = valid_values.sum(dim)
1537+
adjust = valid_count / (valid_count - ddof)
1538+
# I think the cast is required because of `T_DataArray` + `T_Xarray` (would be
1539+
# the same with `T_DatasetOrArray`)
1540+
# https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026
1541+
return cast(T_DataArray, cov * adjust)
14761542

14771543
else:
1478-
# compute std + corr
1479-
da_a_std = da_a.std(dim=dim)
1480-
da_b_std = da_b.std(dim=dim)
1544+
# Compute std and corr
1545+
if weights is not None:
1546+
da_a_std = da_a.weighted(weights).std(dim=dim)
1547+
da_b_std = da_b.weighted(weights).std(dim=dim)
1548+
else:
1549+
da_a_std = da_a.std(dim=dim)
1550+
da_b_std = da_b.std(dim=dim)
14811551
corr = cov / (da_a_std * da_b_std)
1482-
return corr
1552+
return cast(T_DataArray, corr)
14831553

14841554

14851555
def cross(

xarray/tests/test_computation.py

+91
Original file line numberDiff line numberDiff line change
@@ -1775,6 +1775,97 @@ def test_complex_cov() -> None:
17751775
assert abs(actual.item()) == 2
17761776

17771777

1778+
@pytest.mark.parametrize("weighted", [True, False])
1779+
def test_bilinear_cov_corr(weighted: bool) -> None:
1780+
# Test the bilinear properties of covariance and correlation
1781+
da = xr.DataArray(
1782+
np.random.random((3, 21, 4)),
1783+
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
1784+
dims=("a", "time", "x"),
1785+
)
1786+
db = xr.DataArray(
1787+
np.random.random((3, 21, 4)),
1788+
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
1789+
dims=("a", "time", "x"),
1790+
)
1791+
dc = xr.DataArray(
1792+
np.random.random((3, 21, 4)),
1793+
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
1794+
dims=("a", "time", "x"),
1795+
)
1796+
if weighted:
1797+
weights = xr.DataArray(
1798+
np.abs(np.random.random(4)),
1799+
dims=("x"),
1800+
)
1801+
else:
1802+
weights = None
1803+
k = np.random.random(1)[0]
1804+
1805+
# Test covariance properties
1806+
assert_allclose(
1807+
xr.cov(da + k, db, weights=weights), xr.cov(da, db, weights=weights)
1808+
)
1809+
assert_allclose(
1810+
xr.cov(da, db + k, weights=weights), xr.cov(da, db, weights=weights)
1811+
)
1812+
assert_allclose(
1813+
xr.cov(da + dc, db, weights=weights),
1814+
xr.cov(da, db, weights=weights) + xr.cov(dc, db, weights=weights),
1815+
)
1816+
assert_allclose(
1817+
xr.cov(da, db + dc, weights=weights),
1818+
xr.cov(da, db, weights=weights) + xr.cov(da, dc, weights=weights),
1819+
)
1820+
assert_allclose(
1821+
xr.cov(k * da, db, weights=weights), k * xr.cov(da, db, weights=weights)
1822+
)
1823+
assert_allclose(
1824+
xr.cov(da, k * db, weights=weights), k * xr.cov(da, db, weights=weights)
1825+
)
1826+
1827+
# Test correlation properties
1828+
assert_allclose(
1829+
xr.corr(da + k, db, weights=weights), xr.corr(da, db, weights=weights)
1830+
)
1831+
assert_allclose(
1832+
xr.corr(da, db + k, weights=weights), xr.corr(da, db, weights=weights)
1833+
)
1834+
assert_allclose(
1835+
xr.corr(k * da, db, weights=weights), xr.corr(da, db, weights=weights)
1836+
)
1837+
assert_allclose(
1838+
xr.corr(da, k * db, weights=weights), xr.corr(da, db, weights=weights)
1839+
)
1840+
1841+
1842+
def test_equally_weighted_cov_corr() -> None:
1843+
# Test that equal weights for all values produces same results as weights=None
1844+
da = xr.DataArray(
1845+
np.random.random((3, 21, 4)),
1846+
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
1847+
dims=("a", "time", "x"),
1848+
)
1849+
db = xr.DataArray(
1850+
np.random.random((3, 21, 4)),
1851+
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
1852+
dims=("a", "time", "x"),
1853+
)
1854+
#
1855+
assert_allclose(
1856+
xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(1))
1857+
)
1858+
assert_allclose(
1859+
xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(2))
1860+
)
1861+
assert_allclose(
1862+
xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(1))
1863+
)
1864+
assert_allclose(
1865+
xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(2))
1866+
)
1867+
1868+
17781869
@requires_dask
17791870
def test_vectorize_dask_new_output_dims() -> None:
17801871
# regression test for GH3574

0 commit comments

Comments
 (0)