Skip to content

Commit 88c0726

Browse files
Llorenc LledoLlorenc Lledo
Llorenc Lledo
authored and
Llorenc Lledo
committed
Added function _weighted_cov_corr and modified cov and corr to call it if parameter weights is not None
1 parent ce1af97 commit 88c0726

File tree

1 file changed

+75
-5
lines changed

1 file changed

+75
-5
lines changed

xarray/core/computation.py

+75-5
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,11 @@ def apply_ufunc(
12811281

12821282

12831283
def cov(
1284-
da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
1284+
da_a: T_DataArray,
1285+
da_b: T_DataArray,
1286+
dim: Dims = None,
1287+
ddof: int = 1,
1288+
weights: T_DataArray = None,
12851289
) -> T_DataArray:
12861290
"""
12871291
Compute covariance between two DataArray objects along a shared dimension.
@@ -1297,6 +1301,8 @@ def cov(
12971301
ddof : int, default: 1
12981302
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
12991303
else normalization is by N.
1304+
weights : DataArray, default: None
1305+
Array of weights.
13001306
13011307
Returns
13021308
-------
@@ -1358,11 +1364,22 @@ def cov(
13581364
"Only xr.DataArray is supported."
13591365
f"Given {[type(arr) for arr in [da_a, da_b]]}."
13601366
)
1361-
1362-
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
1367+
if weights is not None:
1368+
if not isinstance(weights, DataArray):
1369+
raise TypeError(
1370+
"Only xr.DataArray is supported."
1371+
f"Given {type(weights)}."
1372+
)
1373+
return _weighted_cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov")
1374+
else
1375+
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
13631376

13641377

1365-
def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
1378+
def corr(da_a: T_DataArray,
1379+
da_b: T_DataArray,
1380+
dim: Dims = None,
1381+
weights: T_DataArray = None,
1382+
) -> T_DataArray:
13661383
"""
13671384
Compute the Pearson correlation coefficient between
13681385
two DataArray objects along a shared dimension.
@@ -1375,6 +1392,8 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
13751392
Array to compute.
13761393
dim : str, iterable of hashable, "..." or None, optional
13771394
The dimension along which the correlation will be computed
1395+
weights : DataArray, default: None
1396+
Array of weights.
13781397
13791398
Returns
13801399
-------
@@ -1437,7 +1456,15 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
14371456
f"Given {[type(arr) for arr in [da_a, da_b]]}."
14381457
)
14391458

1440-
return _cov_corr(da_a, da_b, dim=dim, method="corr")
1459+
if weights is not None:
1460+
if not isinstance(weights, DataArray):
1461+
raise TypeError(
1462+
"Only xr.DataArray is supported."
1463+
f"Given {type(weights)}."
1464+
)
1465+
return _weighted_cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr")
1466+
else
1467+
return _cov_corr(da_a, da_b, dim=dim, method="corr")
14411468

14421469

14431470
def _cov_corr(
@@ -1481,6 +1508,49 @@ def _cov_corr(
14811508
corr = cov / (da_a_std * da_b_std)
14821509
return corr
14831510

1511+
def _weighted_cov_corr(
    da_a: T_DataArray,
    da_b: T_DataArray,
    weights: T_DataArray,
    dim: Dims = None,
    ddof: int = 0,
    method: Literal["cov", "corr", None] = None,
) -> T_DataArray:
    """
    Internal method for weighted xr.cov() and xr.corr(), extending
    _cov_corr() functionality.

    Parameters
    ----------
    da_a : DataArray
        First array.
    da_b : DataArray
        Second array.
    weights : DataArray
        Weights handed to ``DataArray.weighted`` for every weighted
        reduction performed here (means, covariance and standard deviations).
    dim : str, iterable of hashable, "..." or None, optional
        The dimension(s) along which the statistic is computed.
    ddof : int, default: 0
        Only used when ``method == "cov"``: the weighted covariance is
        multiplied by ``N / (N - ddof)``, where ``N`` is the number of
        positions along ``dim`` at which both inputs are non-null.
    method : {"cov", "corr", None}, optional
        ``"cov"`` returns the weighted covariance; any other value returns
        the weighted Pearson correlation coefficient.

    Returns
    -------
    DataArray
        Weighted covariance or correlation of ``da_a`` and ``da_b``.
    """
    # 1. Align the two arrays on their shared coordinate labels
    # NOTE(review): `weights` is not aligned explicitly here; it is passed
    # as-is to `.weighted()` -- confirm that is sufficient for callers.
    da_a, da_b = align(da_a, da_b, join="inner", copy=False)

    # 2. Ignore the nans: mask every position where either input is missing
    #    so that both arrays contribute exactly the same samples
    valid_values = da_a.notnull() & da_b.notnull()
    da_a = da_a.where(valid_values)
    da_b = da_b.where(valid_values)

    # 3. Remove the weighted mean along the given dim
    demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim)
    demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim)

    # 4. Compute covariance along the given dim
    # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
    # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
    # The weighted `mean` reduction has no `min_count` argument (unlike the
    # unweighted `sum`), so passing one raises TypeError; it is omitted here.
    cov = (
        (demeaned_da_a.conj() * demeaned_da_b)
        .weighted(weights)
        .mean(dim=dim, skipna=True)
    )

    if method == "cov":
        # Adjust covariance for degrees of freedom
        # NOTE(review): the correction uses the unweighted count of valid
        # samples, mirroring the unweighted code path -- confirm this is the
        # intended normalisation when weights are supplied.
        valid_count = valid_values.sum(dim)
        adjust = valid_count / (valid_count - ddof)
        return cov * adjust

    else:
        # Compute std and corr
        da_a_std = da_a.weighted(weights).std(dim=dim)
        da_b_std = da_b.weighted(weights).std(dim=dim)
        corr = cov / (da_a_std * da_b_std)
        return corr
14841554

14851555
def cross(
14861556
a: DataArray | Variable, b: DataArray | Variable, *, dim: Hashable

0 commit comments

Comments
 (0)