Rewrite interp to use apply_ufunc #9881


Merged: 26 commits, Dec 19, 2024

Conversation

dcherian
Contributor

@dcherian dcherian commented Dec 13, 2024

  1. Removes a bunch of complexity around interpolating dask arrays by using apply_ufunc instead of blockwise directly.
  2. A major improvement: we can now use vectorize=True to get sane dask graphs for vectorized interpolation to chunked arrays (see "interp performance with chunked dimensions" #6799 (comment)).
  3. Added a bunch of typing.
  4. Happily, this fixes "Interpolation with multiple multidimensional arrays sharing dims fails" #4463.
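The vectorize=True point in (2) can be sketched outside xarray: numpy's np.vectorize with a gufunc signature does, conceptually, what apply_ufunc's vectorize=True does, namely loop a 1-D kernel over all broadcast dimensions. This is a minimal sketch under that analogy; interp_1d is a stand-in kernel, not xarray's actual interpolator.

```python
import numpy as np

# A 1-D interpolation kernel that only understands its core dimension.
# This mirrors the per-sample function that apply_ufunc wraps.
def interp_1d(y, x, x_new):
    # y: values on grid x; returns values at the points x_new
    return np.interp(x_new, x, y)

# What vectorize=True does, conceptually: loop the kernel over all
# leading (broadcast) dimensions, so the dask graph stays one task per
# block instead of a quadratic blowup of blockwise dependencies.
vectorized_interp = np.vectorize(interp_1d, signature="(n),(n),(m)->(m)")

x = np.linspace(0, 1, 5)
y = np.stack([x, 2 * x])        # two series on the same grid -> shape (2, 5)
x_new = np.array([0.25, 0.75])

out = vectorized_interp(y, x, x_new)  # shape (2, 2)
```

With real chunked data, apply_ufunc(..., vectorize=True, dask="parallelized") applies the same idea per dask block.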

cc @ks905383 your vectorized interpolation example now has this graph:
[image: dask task graph]

instead of this quadratic monstrosity
[image: dask task graph]

@dcherian dcherian added needs review run-benchmark Run the ASV benchmark workflow labels Dec 13, 2024
@dcherian dcherian requested a review from Illviljan December 13, 2024 06:32
@@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
dcherian (Contributor, Author):

Handled by vectorize=True. This is possibly a perf regression with numpy arrays, but a massive improvement with chunked arrays.

dcherian (Contributor, Author):

For posterity, the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism.

Consider the problem in #6799 (comment). In the following, dimension names are listed in [].

da[time, q, lat, lon].interp(q=bar[lat, lon]) gets rewritten to da[time, q, lat, lon].interp(q=bar[lat, lon], lat=lat[lat], lon=lon[lon]), which, thanks to our automatic rechunking, makes dask merge chunks along lat and lon too, for no benefit.
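A toy calculation (my own sketch, not xarray or dask code) makes the parallelism loss concrete: the number of embarrassingly parallel tasks is roughly the product of chunk counts over the non-core dimensions, so promoting lat and lon to core dimensions collapses it.

```python
from math import prod

# Toy model: apply_ufunc with dask="parallelized" produces one task per
# block of the non-core dimensions; each core dimension is merged to a
# single chunk, so it contributes no parallelism.
def n_parallel_tasks(chunk_counts: dict[str, int], core_dims: set[str]) -> int:
    return prod(n for dim, n in chunk_counts.items() if dim not in core_dims)

# da[time, q, lat, lon] with these chunk counts, interpolating along q
chunks = {"time": 10, "q": 1, "lat": 4, "lon": 4}

only_q = n_parallel_tasks(chunks, {"q"})                   # 10 * 4 * 4 tasks
rewritten = n_parallel_tasks(chunks, {"q", "lat", "lon"})  # only time remains
```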

def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
    """Wrapper for `_interpnd` through `blockwise` for chunked arrays."""

def _interpnd(
dcherian (Contributor, Author):

I merged in two functions to reduce indirection and make it easier to read.

exclude_dims=all_in_core_dims,
dask="parallelized",
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
dcherian (Contributor, Author):

allow_rechunk=True matches the current behaviour where we rechunk along all core dimensions to a single chunk.
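As an illustration of what that rechunking means at the chunk level (a toy helper of mine, not dask's API):

```python
# Toy helper: mimic rechunking every core dimension to a single chunk,
# which is the rechunking that allow_rechunk=True permits. `chunks` is a
# dask-style tuple of per-axis chunk-size tuples.
def merge_core_dim_chunks(chunks, core_axes):
    return tuple(
        (sum(c),) if axis in core_axes else c
        for axis, c in enumerate(chunks)
    )

# a 10x6 array chunked (5, 5) x (3, 3); axis 1 is a core dimension
merged = merge_core_dim_chunks(((5, 5), (3, 3)), core_axes={1})
```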

@dcherian dcherian force-pushed the redo-blockwise-interp branch from 245697e to a5e1854 Compare December 14, 2024 00:06
@dcherian dcherian force-pushed the redo-blockwise-interp branch from 652a239 to 586f638 Compare December 14, 2024 04:02
@Illviljan Illviljan mentioned this pull request Dec 14, 2024
1 task
@dcherian
Contributor Author

Merging on Thursday if there are no comments.

IMO this is a big win for maintainability.

@dcherian dcherian added plan to merge Final call for comments and removed needs review labels Dec 17, 2024
@@ -566,29 +577,30 @@ def _get_valid_fill_mask(arr, dim, limit):
) <= limit


def _localize(var, indexes_coords):
def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
Collaborator:

Probably should use T_Xarray instead of a plain T to get rid of the `type: ignore` at the return.

dcherian (Contributor, Author):

That doesn't have Variable, so I'd have to make a new T_DatasetOrVariable or a protocol with .isel perhaps?
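The protocol idea floated here can be sketched as follows (hypothetical names: SupportsIsel, localize, and FakeVariable are illustrations, not xarray code).

```python
from typing import Protocol, TypeVar

# Hypothetical protocol: anything with an .isel method (Dataset,
# DataArray, Variable) satisfies it structurally, so _localize could be
# typed without a new T_DatasetOrVariable union or a `type: ignore`.
class SupportsIsel(Protocol):
    def isel(self, indexers=None, **indexers_kwargs): ...

T = TypeVar("T", bound=SupportsIsel)

def localize(obj: T) -> T:
    # stub: a real version would subset `obj` to the region needed
    return obj.isel()

class FakeVariable:
    """Minimal stand-in exposing the .isel structural interface."""
    def isel(self, indexers=None, **indexers_kwargs):
        return self
```

Because Protocol matching is structural, no xarray class would need to inherit from SupportsIsel.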

@Illviljan Illviljan left a comment (Contributor):

Benchmarks still look good. Nice work!

Comment on lines +830 to +831
# TODO: narrow interp_func to interpolator here
return _interp1d(var, x_list, new_x_list, interp_func, interp_kwargs) # type: ignore[arg-type]
Contributor:

Mypy is correct to error here, right?
_interp1d calls interp_func(...)(...), and that should crash with an InterpCallable?
Is there a pytest with interp_func: InterpCallable?
Is InterpCallable necessary? It would be nice to just remove it...

dcherian (Contributor, Author):

It depends on whether we end up using get_interpolator or get_interpolator_nd. I'm sure there's a test, but I can't remember which one off the top of my head.

@dcherian dcherian merged commit 29fe679 into pydata:main Dec 19, 2024
29 checks passed
@dcherian dcherian deleted the redo-blockwise-interp branch December 19, 2024 16:30
dcherian added a commit to dcherian/xarray that referenced this pull request Mar 19, 2025
* main: (63 commits)
  Fix zarr upstream tests (pydata#9927)
  Update pre-commit hooks (pydata#9925)
  split out CFDatetimeCoder, deprecate use_cftime as kwarg (pydata#9901)
  dev whats-new (pydata#9923)
  Whats-new 2025.01.0 (pydata#9919)
  Silence upstream Zarr warnings (pydata#9920)
  time coding refactor (pydata#9906)
  fix warning from scipy backend guess_can_open on directory (pydata#9911)
  Enhance and move ISO-8601 parser to coding.times (pydata#9899)
  Edit serialization error message (pydata#9916)
  friendlier error messages for missing chunk managers (pydata#9676)
  Bump codecov/codecov-action from 5.1.1 to 5.1.2 in the actions group (pydata#9915)
  Rewrite interp to use `apply_ufunc` (pydata#9881)
  Skip dask rolling (pydata#9909)
  Explicitly configure ReadTheDocs build to use conf.py (pydata#9908)
  Cache pre-existing Zarr arrays in Zarr backend (pydata#9861)
  Optimize idxmin, idxmax with dask (pydata#9800)
  remove unused "type: ignore" comments in test_plot.py (fixed in matplotlib 3.10.0) (pydata#9904)
  move scalar-handling logic into `possibly_convert_objects` (pydata#9900)
  Add missing DataTree attributes to docs (pydata#9876)
  ...
dcherian added a commit to dcherian/xarray that referenced this pull request May 28, 2025
dcherian added a commit to dcherian/xarray that referenced this pull request May 28, 2025
dcherian added a commit to dcherian/xarray that referenced this pull request May 28, 2025
dcherian added a commit to dcherian/xarray that referenced this pull request May 30, 2025
* main:
  Fix performance regression in interp from pydata#9881 (pydata#10370)
  html repr: improve style for dropdown sections (pydata#10354)
  Grouper tweaks. (pydata#10362)
  Docs: Add links to getting help mermaid diagram (pydata#10324)
  Enforce ruff/flynt rules (FLY) (pydata#10375)
  Add missing AbstractWritableDataStore base methods and arguments (pydata#10343)
  Improve html repr in dark mode (Jupyterlab + Xarray docs) (pydata#10353)
  Pin Mypy to 1.15 (pydata#10378)
  use numpy dtype exposed by zarr array instead of metadata.data_type (pydata#10348)
  Fix doc typo for caption "Interoperability" (pydata#10374)
  Implement cftime vectorization as discussed in PR pydata#8322 (pydata#8324)
  Enforce ruff/flake8-pyi rules (PYI) (pydata#10359)
  Apply assorted ruff/Pylint rules (PL) / Enforce PLE rules (pydata#10366)
  (fix): pandas extension array repr for int64[pyarrow] (pydata#10317)
  Enforce ruff/flake8-implicit-str-concat rules (ISC) (pydata#10368)
  Enforce ruff/refurb rules (FURB) (pydata#10367)
  Ignore ruff/Pyflakes rule F401 more precisely (pydata#10369)
  Apply assorted ruff/flake8-simplify rules (SIM) (pydata#10364)
  Apply assorted ruff/flake8-pytest-style rules (PT) (pydata#10363)
  Fix "a array" misspelling (pydata#10365)
Labels
plan to merge Final call for comments run-benchmark Run the ASV benchmark workflow topic-interpolation

Successfully merging this pull request may close these issues.

Interpolation with multiple multidimensional arrays sharing dims fails
3 participants