9
9
import warnings
10
10
from collections import Counter
11
11
from collections .abc import Hashable , Iterable , Iterator , Mapping , Sequence , Set
12
- from typing import TYPE_CHECKING , Any , Callable , Literal , TypeVar , Union , overload
12
+ from typing import TYPE_CHECKING , Any , Callable , Literal , TypeVar , Union , cast , overload
13
13
14
14
import numpy as np
15
15
@@ -1281,7 +1281,11 @@ def apply_ufunc(
1281
1281
1282
1282
1283
1283
def cov (
1284
- da_a : T_DataArray , da_b : T_DataArray , dim : Dims = None , ddof : int = 1
1284
+ da_a : T_DataArray ,
1285
+ da_b : T_DataArray ,
1286
+ dim : Dims = None ,
1287
+ ddof : int = 1 ,
1288
+ weights : T_DataArray | None = None ,
1285
1289
) -> T_DataArray :
1286
1290
"""
1287
1291
Compute covariance between two DataArray objects along a shared dimension.
@@ -1297,6 +1301,8 @@ def cov(
1297
1301
ddof : int, default: 1
1298
1302
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
1299
1303
else normalization is by N.
1304
+ weights : DataArray, optional
1305
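+ Array of weights. If given, the means and the covariance along ``dim`` are computed as weighted averages via ``DataArray.weighted``; the ``ddof`` correction still uses the unweighted count of valid values.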
1300
1306
1301
1307
Returns
1302
1308
-------
@@ -1350,6 +1356,23 @@ def cov(
1350
1356
array([ 0.2 , -0.5 , 1.69333333])
1351
1357
Coordinates:
1352
1358
* space (space) <U2 'IA' 'IL' 'IN'
1359
+ >>> weights = DataArray(
1360
+ ... [4, 2, 1],
1361
+ ... dims=("space"),
1362
+ ... coords=[
1363
+ ... ("space", ["IA", "IL", "IN"]),
1364
+ ... ],
1365
+ ... )
1366
+ >>> weights
1367
+ <xarray.DataArray (space: 3)>
1368
+ array([4, 2, 1])
1369
+ Coordinates:
1370
+ * space (space) <U2 'IA' 'IL' 'IN'
1371
+ >>> xr.cov(da_a, da_b, dim="space", weights=weights)
1372
+ <xarray.DataArray (time: 3)>
1373
+ array([-4.69346939, -4.49632653, -3.37959184])
1374
+ Coordinates:
1375
+ * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
1353
1376
"""
1354
1377
from xarray .core .dataarray import DataArray
1355
1378
@@ -1358,11 +1381,18 @@ def cov(
1358
1381
"Only xr.DataArray is supported."
1359
1382
f"Given { [type (arr ) for arr in [da_a , da_b ]]} ."
1360
1383
)
1384
+ if weights is not None :
1385
+ if not isinstance (weights , DataArray ):
1386
+ raise TypeError ("Only xr.DataArray is supported." f"Given { type (weights )} ." )
1387
+ return _cov_corr (da_a , da_b , weights = weights , dim = dim , ddof = ddof , method = "cov" )
1361
1388
1362
- return _cov_corr (da_a , da_b , dim = dim , ddof = ddof , method = "cov" )
1363
1389
1364
-
1365
- def corr (da_a : T_DataArray , da_b : T_DataArray , dim : Dims = None ) -> T_DataArray :
1390
+ def corr (
1391
+ da_a : T_DataArray ,
1392
+ da_b : T_DataArray ,
1393
+ dim : Dims = None ,
1394
+ weights : T_DataArray | None = None ,
1395
+ ) -> T_DataArray :
1366
1396
"""
1367
1397
Compute the Pearson correlation coefficient between
1368
1398
two DataArray objects along a shared dimension.
@@ -1375,6 +1405,8 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
1375
1405
Array to compute.
1376
1406
dim : str, iterable of hashable, "..." or None, optional
1377
1407
The dimension along which the correlation will be computed
1408
+ weights : DataArray, optional
1409
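+ Array of weights. If given, the means, covariance and standard deviations along ``dim`` are computed as weighted statistics via ``DataArray.weighted``.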
1378
1410
1379
1411
Returns
1380
1412
-------
@@ -1428,6 +1460,23 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
1428
1460
array([ 1., -1., 1.])
1429
1461
Coordinates:
1430
1462
* space (space) <U2 'IA' 'IL' 'IN'
1463
+ >>> weights = DataArray(
1464
+ ... [4, 2, 1],
1465
+ ... dims=("space"),
1466
+ ... coords=[
1467
+ ... ("space", ["IA", "IL", "IN"]),
1468
+ ... ],
1469
+ ... )
1470
+ >>> weights
1471
+ <xarray.DataArray (space: 3)>
1472
+ array([4, 2, 1])
1473
+ Coordinates:
1474
+ * space (space) <U2 'IA' 'IL' 'IN'
1475
+ >>> xr.corr(da_a, da_b, dim="space", weights=weights)
1476
+ <xarray.DataArray (time: 3)>
1477
+ array([-0.50240504, -0.83215028, -0.99057446])
1478
+ Coordinates:
1479
+ * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
1431
1480
"""
1432
1481
from xarray .core .dataarray import DataArray
1433
1482
@@ -1436,13 +1485,16 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
1436
1485
"Only xr.DataArray is supported."
1437
1486
f"Given { [type (arr ) for arr in [da_a , da_b ]]} ."
1438
1487
)
1439
-
1440
- return _cov_corr (da_a , da_b , dim = dim , method = "corr" )
1488
+ if weights is not None :
1489
+ if not isinstance (weights , DataArray ):
1490
+ raise TypeError ("Only xr.DataArray is supported." f"Given { type (weights )} ." )
1491
+ return _cov_corr (da_a , da_b , weights = weights , dim = dim , method = "corr" )
1441
1492
1442
1493
1443
1494
def _cov_corr (
1444
1495
da_a : T_DataArray ,
1445
1496
da_b : T_DataArray ,
1497
+ weights : T_DataArray | None = None ,
1446
1498
dim : Dims = None ,
1447
1499
ddof : int = 0 ,
1448
1500
method : Literal ["cov" , "corr" , None ] = None ,
@@ -1458,28 +1510,46 @@ def _cov_corr(
1458
1510
valid_values = da_a .notnull () & da_b .notnull ()
1459
1511
da_a = da_a .where (valid_values )
1460
1512
da_b = da_b .where (valid_values )
1461
- valid_count = valid_values .sum (dim ) - ddof
1462
1513
1463
1514
# 3. Demean along the given dim (subtract the weighted mean when weights are given)
1464
- demeaned_da_a = da_a - da_a .mean (dim = dim )
1465
- demeaned_da_b = da_b - da_b .mean (dim = dim )
1515
+ if weights is not None :
1516
+ demeaned_da_a = da_a - da_a .weighted (weights ).mean (dim = dim )
1517
+ demeaned_da_b = da_b - da_b .weighted (weights ).mean (dim = dim )
1518
+ else :
1519
+ demeaned_da_a = da_a - da_a .mean (dim = dim )
1520
+ demeaned_da_b = da_b - da_b .mean (dim = dim )
1466
1521
1467
1522
# 4. Compute covariance along the given dim
1468
1523
# N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
1469
1524
# Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
1470
- cov = (demeaned_da_a .conj () * demeaned_da_b ).sum (
1471
- dim = dim , skipna = True , min_count = 1
1472
- ) / (valid_count )
1525
+ if weights is not None :
1526
+ cov = (
1527
+ (demeaned_da_a .conj () * demeaned_da_b )
1528
+ .weighted (weights )
1529
+ .mean (dim = dim , skipna = True )
1530
+ )
1531
+ else :
1532
+ cov = (demeaned_da_a .conj () * demeaned_da_b ).mean (dim = dim , skipna = True )
1473
1533
1474
1534
if method == "cov" :
1475
- return cov
1535
+ # Adjust covariance for degrees of freedom
1536
+ valid_count = valid_values .sum (dim )
1537
+ adjust = valid_count / (valid_count - ddof )
1538
+ # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be
1539
+ # the same with `T_DatasetOrArray`)
1540
+ # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026
1541
+ return cast (T_DataArray , cov * adjust )
1476
1542
1477
1543
else :
1478
- # compute std + corr
1479
- da_a_std = da_a .std (dim = dim )
1480
- da_b_std = da_b .std (dim = dim )
1544
+ # Compute std and corr
1545
+ if weights is not None :
1546
+ da_a_std = da_a .weighted (weights ).std (dim = dim )
1547
+ da_b_std = da_b .weighted (weights ).std (dim = dim )
1548
+ else :
1549
+ da_a_std = da_a .std (dim = dim )
1550
+ da_b_std = da_b .std (dim = dim )
1481
1551
corr = cov / (da_a_std * da_b_std )
1482
- return corr
1552
+ return cast ( T_DataArray , corr )
1483
1553
1484
1554
1485
1555
def cross (
0 commit comments