forked from pydata/xarray

Commit 39582db

Optimize polyfit
Closes pydata#5629

1. Use Variable instead of DataArray
2. Use `reshape_blockwise` when possible, following pydata#5629 (comment)
1 parent 91962d6 commit 39582db
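
For context, a minimal usage sketch of the case this commit targets: polyfit along one dimension of a multi-dimensional, dask-backed DataArray (mirroring the new test added below; dimension names and sizes are illustrative, and dask must be installed):

import numpy as np
import xarray as xr

# N-D, chunked input; previously the extra dimensions were stacked into a
# temporary dimension, which forces a rechunk under dask.
da = (
    xr.DataArray(np.arange(120), dims="time", coords={"time": np.arange(120)})
    .chunk({"time": 20})
    .expand_dims(lat=5, lon=5)
    .chunk({"lat": 2, "lon": 2})
)
coeffs = da.polyfit("time", 1, skipna=False)  # lazy degree-1 fit along "time"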

File tree

6 files changed: +100 -35 lines changed


xarray/core/dask_array_compat.py

+16

@@ -0,0 +1,16 @@
+from typing import Any
+
+from xarray.namedarray.utils import module_available
+
+
+def reshape_blockwise(
+    x: Any,
+    shape: int | tuple[int, ...],
+    chunks: tuple[tuple[int, ...], ...] | None = None,
+):
+    if module_available("dask", "2024.08.2"):
+        from dask.array import reshape_blockwise
+
+        return reshape_blockwise(x, shape=shape, chunks=chunks)
+    else:
+        return x.reshape(shape)
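
A rough sketch of how this shim behaves (assumes dask is installed; shapes and chunk sizes are illustrative):

import dask.array as da
from xarray.core.dask_array_compat import reshape_blockwise

x = da.ones((10, 4, 3), chunks=(5, 2, 3))
# With dask >= 2024.08.2 this dispatches to dask.array.reshape_blockwise,
# which reshapes block-by-block instead of rechunking; on older dask it
# falls back to a plain x.reshape(shape).
y = reshape_blockwise(x, shape=(10, 12))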

xarray/core/dask_array_ops.py

+22

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import math
+
 from xarray.core import dtypes, nputils
 
 
@@ -29,6 +31,15 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
 def least_squares(lhs, rhs, rcond=None, skipna=False):
     import dask.array as da
 
+    from xarray.core.dask_array_compat import reshape_blockwise
+
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        reshape_chunks = rhs.chunks
+        rhs = reshape_blockwise(rhs, (rhs.shape[0], math.prod(rhs.shape[1:])))
+    else:
+        out_shape = None
+
     lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
     if skipna:
         added_dim = rhs.ndim == 1
@@ -52,6 +63,17 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
         # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
         # See issue dask/dask#6516
         coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
+
+    if out_shape is not None:
+        coeffs = reshape_blockwise(
+            coeffs,
+            shape=(coeffs.shape[0], *out_shape[1:]),
+            chunks=((coeffs.shape[0],), *reshape_chunks[1:]),
+        )
+        residuals = reshape_blockwise(
+            residuals, shape=out_shape[1:], chunks=reshape_chunks[1:]
+        )
+
     return coeffs, residuals
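
In words: da.linalg.lstsq only handles a 2-D right-hand side, so any extra trailing dimensions are collapsed into one before the solve and expanded again afterwards, reusing the original trailing chunks. A hedged round-trip sketch of that reshape (shapes and chunks are illustrative, not taken from the commit):

import dask.array as da
from xarray.core.dask_array_compat import reshape_blockwise

rhs = da.ones((120, 5, 5), chunks=(20, 2, 2))   # (N, lat, lon)
out_shape, reshape_chunks = rhs.shape, rhs.chunks

# Collapse the trailing dimensions blockwise -> (N, K) for the least-squares solve ...
rhs_2d = reshape_blockwise(rhs, (rhs.shape[0], 25))
# ... and expand them back afterwards, passing the original chunks in again.
restored = reshape_blockwise(rhs_2d, shape=out_shape, chunks=reshape_chunks)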

xarray/core/dataset.py

+39 -34

@@ -9132,34 +9132,39 @@ def polyfit(
             variables[sing.name] = sing
 
         # If we have a coordinate get its underlying dimension.
-        true_dim = self.coords[dim].dims[0]
+        (true_dim,) = self.coords[dim].dims
 
-        for name, da in self.data_vars.items():
-            if true_dim not in da.dims:
+        other_coords = {
+            dim: self._variables[dim]
+            for dim in set(self.dims) - {true_dim}
+            if dim in self._variables
+        }
+        present_dims = set()
+        for name, var in self._variables.items():
+            if name in self._coord_names or name in self.dims:
+                continue
+            if true_dim not in var.dims:
                 continue
 
-            if is_duck_dask_array(da.data) and (
+            if is_duck_dask_array(var._data) and (
                 rank != order or full or skipna is None
             ):
                 # Current algorithm with dask and skipna=False neither supports
                 # deficient ranks nor does it output the "full" info (issue dask/dask#6516)
                 skipna_da = True
             elif skipna is None:
-                skipna_da = bool(np.any(da.isnull()))
-
-            dims_to_stack = [dimname for dimname in da.dims if dimname != true_dim]
-            stacked_coords: dict[Hashable, DataArray] = {}
-            if dims_to_stack:
-                stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
-                rhs = da.transpose(true_dim, *dims_to_stack).stack(
-                    {stacked_dim: dims_to_stack}
-                )
-                stacked_coords = {stacked_dim: rhs[stacked_dim]}
-                scale_da = scale[:, np.newaxis]
+                skipna_da = bool(np.any(var.isnull()))
+
+            if var.ndim > 1:
+                rhs = var.transpose(true_dim, ...)
+                other_dims = rhs.dims[1:]
+                scale_da = scale.reshape(-1, *((1,) * len(other_dims)))
             else:
-                rhs = da
+                rhs = var
                 scale_da = scale
+                other_dims = ()
 
+            present_dims.update(*other_dims)
             if w is not None:
                 rhs = rhs * w[:, np.newaxis]
 
@@ -9179,26 +9184,15 @@ def polyfit(
                 # Thus a ReprObject => polyfit was called on a DataArray
                 name = ""
 
-            coeffs = DataArray(
-                coeffs / scale_da,
-                dims=[degree_dim] + list(stacked_coords.keys()),
-                coords={degree_dim: np.arange(order)[::-1], **stacked_coords},
-                name=name + "polyfit_coefficients",
-            )
-            if dims_to_stack:
-                coeffs = coeffs.unstack(stacked_dim)
-            variables[coeffs.name] = coeffs
+            coeffs = Variable(data=coeffs / scale_da, dims=(degree_dim,) + other_dims)
+            variables[name + "polyfit_coefficients"] = coeffs
 
             if full or (cov is True):
-                residuals = DataArray(
-                    residuals if dims_to_stack else residuals.squeeze(),
-                    dims=list(stacked_coords.keys()),
-                    coords=stacked_coords,
-                    name=name + "polyfit_residuals",
+                residuals = Variable(
+                    data=residuals if var.ndim > 1 else residuals.squeeze(),
+                    dims=other_dims,
                 )
-                if dims_to_stack:
-                    residuals = residuals.unstack(stacked_dim)
-                variables[residuals.name] = residuals
+                variables[name + "polyfit_residuals"] = residuals
 
             if cov:
                 Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
@@ -9214,7 +9208,18 @@
                 covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
                 variables[name + "polyfit_covariance"] = covariance
 
-        return type(self)(data_vars=variables, attrs=self.attrs.copy())
+        return type(self)(
+            data_vars=variables,
+            coords={
+                degree_dim: np.arange(order)[::-1],
+                **{
+                    name: coord
+                    for name, coord in other_coords.items()
+                    if name in present_dims
+                },
+            },
+            attrs=self.attrs.copy(),
+        )
 
     def pad(
         self,
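
The Dataset-level change builds each result as a bare Variable (dims + data only), so no per-variable coordinates or indexes are created inside the loop; the dimension coordinates are attached once when the output Dataset is assembled. A minimal sketch of that Variable construction (names and shapes are illustrative, not from the commit):

import numpy as np
from xarray import Variable

coeffs = np.random.rand(2, 5, 5)  # (degree, lat, lon)
# A Variable carries only dims, data and attrs; unlike DataArray it builds no
# coordinates or indexes, which keeps the per-variable work in polyfit cheap.
var = Variable(data=coeffs, dims=("degree", "lat", "lon"))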

xarray/core/nputils.py

+10

@@ -255,6 +255,12 @@ def warn_on_deficient_rank(rank, order):
 
 
 def least_squares(lhs, rhs, rcond=None, skipna=False):
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        rhs = rhs.reshape(rhs.shape[0], -1)
+    else:
+        out_shape = None
+
     if skipna:
         added_dim = rhs.ndim == 1
         if added_dim:
@@ -281,6 +287,10 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
         if residuals.size == 0:
            residuals = coeffs[0] * np.nan
         warn_on_deficient_rank(rank, lhs.shape[1])
+
+    if out_shape is not None:
+        coeffs = coeffs.reshape(-1, *out_shape[1:])
+        residuals = residuals.reshape(*out_shape[1:])
     return coeffs, residuals
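
The NumPy path mirrors the dask path: flatten the trailing dimensions of rhs into a single column axis, solve, then reshape the coefficients and residuals back. A plain-NumPy sketch of the round trip (shapes are illustrative):

import numpy as np

lhs = np.random.rand(120, 2)      # (N, order + 1) design matrix
rhs = np.random.rand(120, 5, 5)   # (N, lat, lon) samples to fit

out_shape = rhs.shape
rhs_2d = rhs.reshape(rhs.shape[0], -1)                       # (N, K)
coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, rhs_2d, rcond=None)
coeffs = coeffs.reshape(-1, *out_shape[1:])                  # (order + 1, lat, lon)
residuals = residuals.reshape(*out_shape[1:])                # (lat, lon)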

xarray/tests/test_dataarray.py

+12

@@ -4308,6 +4308,18 @@ def test_polyfit(self, use_dask, use_datetime) -> None:
         out = da.polyfit("x", 8, full=True)
         np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False])
 
+    @requires_dask
+    def test_polyfit_nd_dask(self) -> None:
+        da = (
+            DataArray(np.arange(120), dims="time", coords={"time": np.arange(120)})
+            .chunk({"time": 20})
+            .expand_dims(lat=5, lon=5)
+            .chunk({"lat": 2, "lon": 2})
+        )
+        actual = da.polyfit("time", 1, skipna=False)
+        expected = da.compute().polyfit("time", 1, skipna=False)
+        assert_allclose(actual, expected)
+
     def test_pad_constant(self) -> None:
         ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5))
         actual = ar.pad(dim_0=(1, 3))

xarray/tests/test_dataset.py

+1 -1

@@ -6698,7 +6698,7 @@ def test_polyfit_coord(self) -> None:
 
         out = ds.polyfit("numbers", 2, full=False)
         assert "var3_polyfit_coefficients" in out
-        assert "dim1" in out
+        assert "dim1" in out.dims
         assert "dim2" not in out
         assert "dim3" not in out