
Automatic Dask-Zarr chunk alignment #10336


Draft · wants to merge 2 commits into base: main
8 changes: 7 additions & 1 deletion xarray/backends/api.py
@@ -2132,6 +2132,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
write_empty_chunks: bool | None = None,
@@ -2155,6 +2156,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
write_empty_chunks: bool | None = None,
@@ -2176,6 +2178,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
@@ -2225,13 +2228,16 @@ def to_zarr(
append_dim=append_dim,
write_region=region,
safe_chunks=safe_chunks,
align_chunks=align_chunks,
zarr_version=zarr_version,
zarr_format=zarr_format,
write_empty=write_empty_chunks,
**kwargs,
)

dataset = zstore._validate_and_autodetect_region(dataset)
dataset = zstore._validate_and_autodetect_region(
dataset,
)
zstore._validate_encoding(encoding)

writer = ArrayWriter()
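For orientation, a minimal sketch (editorial illustration, not part of the diff) of how the new align_chunks flag is meant to be used from the public API; the dataset, store path, and chunk sizes below are hypothetical:

import numpy as np
import xarray as xr

ds = xr.Dataset({"foo": ("x", np.arange(10))}).chunk({"x": 3})
# Dask chunks of 3 do not line up with Zarr chunks of 4; with the
# default safe_chunks=True this write would normally raise, but
# align_chunks=True rechunks the variable on the fly so each Dask
# task writes whole Zarr chunks.
ds.to_zarr("aligned.zarr", mode="w", encoding={"foo": {"chunks": (4,)}}, align_chunks=True)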
177 changes: 177 additions & 0 deletions xarray/backends/common.py
@@ -744,3 +744,180 @@ def open_groups_as_dict(

# mapping of engine name to (module name, BackendEntrypoint Class)
BACKEND_ENTRYPOINTS: dict[str, tuple[str | None, type[BackendEntrypoint]]] = {}


class ChunksUtilities:
@staticmethod
def get_aligned_chunks(
nd_var_chunks: tuple[tuple[int, ...], ...],
nd_backend_chunks: tuple[tuple[int, ...], ...],
) -> tuple[tuple[int, ...], ...]:
if len(nd_backend_chunks) != len(nd_var_chunks):
raise ValueError(
"The number of dimensions on the backend and the variable "
"must be the same."
)

nd_aligned_chunks = []
for var_chunks, backend_chunks in zip(
nd_var_chunks, nd_backend_chunks, strict=True
):
# Create a mutable copy of the variable's chunks
var_chunks = list(var_chunks)

# Validate that they cover the same total number of elements
if sum(backend_chunks) != sum(var_chunks):
raise ValueError(
"The number of elements on the backend is different from "
"the number of elements on the variable; "
"this should never happen at this point."
)

# Validate that the backend chunks satisfy the condition that all
# values, excluding the borders, are equal
if len(set(backend_chunks[1:-1])) > 1:
raise ValueError(
"For now, this function only supports aligning chunks "
"when the backend chunks are equal, excluding the borders. "
"In other words, the backend chunks must form a regular grid."
)

# The algorithm assumes that there are always two borders on the
# backend and the array; if not, the result is the same as the
# input and there is nothing to optimize
if len(backend_chunks) == 1:
nd_aligned_chunks.append(backend_chunks)
continue

if len(var_chunks) == 1:
nd_aligned_chunks.append(var_chunks)
continue

# Size of the chunk on the backend
fixed_chunk = max(backend_chunks)

# The ideal chunk size is the maximum of the two; this avoids
# using more memory than expected
max_chunk = max(fixed_chunk, max(var_chunks))

# The algorithm assumes that the chunks of this array are aligned,
# except for the last one, which can be a partial chunk
aligned_chunks = []

# For simplicity of the algorithm, let's transform the Array chunks in such a way that
# we remove the partial chunks. To achieve this, we add artificial data to the borders
var_chunks[0] += fixed_chunk - backend_chunks[0]
var_chunks[-1] += fixed_chunk - backend_chunks[-1]

# The unfilled_size is the amount of space that has not been filled on the last
# processed chunk; this is equivalent to the amount of data that would need to be
# added to a partial Zarr chunk to fill it up to the fixed_chunk size
unfilled_size = 0

for var_chunk in var_chunks:
# Ideally, we should try to preserve the original Dask chunks, but this is only
# possible if the last processed chunk was aligned (unfilled_size == 0)
ideal_chunk = var_chunk
if unfilled_size:
# If that scenario is not possible, the best option is to merge the chunks
ideal_chunk = var_chunk + aligned_chunks[-1]

while ideal_chunk:
if not unfilled_size:
# If the previous chunk is filled, let's add a new chunk
# of size 0 that will be used on the merging step to simplify the algorithm
aligned_chunks.append(0)

if ideal_chunk > max_chunk:
# If the ideal_chunk is bigger than the max_chunk,
# we need to increase the last chunk as much as possible,
# while keeping it aligned, and then add a new chunk
max_increase = max_chunk - aligned_chunks[-1]
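# Rounding (max_increase - unfilled_size) down to a multiple
# of fixed_chunk guarantees that the grown chunk ends exactly
# on a backend chunk boundary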
max_increase = (
max_increase - (max_increase - unfilled_size) % fixed_chunk
)
aligned_chunks[-1] += max_increase
else:
# Perfect scenario where the chunks can be merged without any split.
aligned_chunks[-1] = ideal_chunk

ideal_chunk -= aligned_chunks[-1]
unfilled_size = (
fixed_chunk - aligned_chunks[-1] % fixed_chunk
) % fixed_chunk

# Now we have to remove the artificial data added to the borders
for order in [-1, 1]:
border_size = fixed_chunk - backend_chunks[::order][0]
aligned_chunks = aligned_chunks[::order]
aligned_chunks[0] -= border_size
var_chunks = var_chunks[::order]
var_chunks[0] -= border_size
if (
len(aligned_chunks) >= 2
and aligned_chunks[0] + aligned_chunks[1] <= max_chunk
and aligned_chunks[0] != var_chunks[0]
):
# The artificial data added to the borders can introduce inefficient
# chunks at the borders; for that reason, we check whether they can be merged.
# Example:
# backend_chunks = [6, 6, 1]
# var_chunks = [6, 7]
# transformed_var_chunks = [6, 12]
# The ideal output should preserve the original var_chunks, but the loop
# above produces aligned_chunks = [6, 6, 6].
# After removing the artificial data we end up with aligned_chunks = [6, 6, 1],
# which is not ideal; the two border chunks are merged back into a
# single chunk, recovering [6, 7].
aligned_chunks[1] += aligned_chunks[0]
aligned_chunks = aligned_chunks[1:]

var_chunks = var_chunks[::order]
aligned_chunks = aligned_chunks[::order]

nd_aligned_chunks.append(tuple(aligned_chunks))

return tuple(nd_aligned_chunks)
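
To make the loop above concrete, a worked example (editorial illustration, matching the test added at the end of this diff): variable chunks (3, 1, 1) against an in-region backend grid of (3, 2) merge the two trailing partial chunks so that every Dask chunk covers whole Zarr chunks:

ChunksUtilities.get_aligned_chunks(
nd_var_chunks=((3, 1, 1),),
nd_backend_chunks=((3, 2),),
)
# -> ((3, 2),)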

@staticmethod
def get_chunks_on_region(n_elements: int, region: slice | None, chunk_size: int) -> list[int]:
if region is None:
region = slice(0, n_elements)
# Generate the zarr chunks inside the region of this dim
chunks_on_region = [chunk_size - (region.start % chunk_size)]
chunks_on_region.extend(
[chunk_size] * ((n_elements - chunks_on_region[0]) // chunk_size)
)
if (n_elements - chunks_on_region[0]) % chunk_size != 0:
chunks_on_region.append((n_elements - chunks_on_region[0]) % chunk_size)
return chunks_on_region
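
For example (illustrative values): with 11 elements, a write region starting at global index 2, and a Zarr chunk size of 3, the first in-region chunk holds only the single element left in the Zarr chunk the region starts in, followed by full chunks and a trailing partial one:

ChunksUtilities.get_chunks_on_region(11, slice(2, 13), 3)
# -> [1, 3, 3, 3, 1]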

@staticmethod
def align_variable_chunks(
v: Variable,
enc_chunks: tuple[int, ...],
regions: tuple[slice, ...],
) -> Variable:
nd_var_chunks = v.chunks
if not nd_var_chunks:
return v

nd_backend_chunks = []
for var_chunks, chunk_size, region in zip(
nd_var_chunks, enc_chunks, regions, strict=True
):
nd_backend_chunks.append(
ChunksUtilities.get_chunks_on_region(
sum(var_chunks),
region,
chunk_size,
)
)

aligned_chunks = ChunksUtilities.get_aligned_chunks(
nd_var_chunks=nd_var_chunks,
nd_backend_chunks=tuple(nd_backend_chunks),
)
v = v.chunk(
{dim: chunk for dim, chunk in zip(v.dims, aligned_chunks, strict=True)}
)
return v
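
End to end, a hedged sketch (assumes dask is installed; names and sizes are illustrative): a variable with Dask chunks (4, 4, 4) written against a Zarr grid of chunk size 5 over 12 elements is rechunked to the backend grid itself:

import numpy as np
import xarray as xr

v = xr.DataArray(np.arange(12), dims=["x"]).chunk(x=(4, 4, 4)).variable
aligned = ChunksUtilities.align_variable_chunks(v, enc_chunks=(5,), regions=(slice(0, 12),))
assert aligned.chunks == ((5, 5, 2),)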
28 changes: 26 additions & 2 deletions xarray/backends/zarr.py
@@ -16,6 +16,7 @@
AbstractWritableDataStore,
BackendArray,
BackendEntrypoint,
ChunksUtilities,
_encode_variable_name,
_normalize_path,
datatree_from_dict_with_io_cleanup,
@@ -621,6 +622,7 @@ class ZarrStore(AbstractWritableDataStore):
"""Store for reading and writing data via zarr"""

__slots__ = (
"_align_chunks",
"_append_dim",
"_cache_members",
"_close_store_on_close",
@@ -651,6 +653,7 @@ def open_store(
append_dim=None,
write_region=None,
safe_chunks=True,
align_chunks=False,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
@@ -698,6 +701,7 @@ def open_store(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
align_chunks=align_chunks,
cache_members=cache_members,
)
for group, group_store in group_members.items()
@@ -718,6 +722,7 @@ def open_group(
append_dim=None,
write_region=None,
safe_chunks=True,
align_chunks=False,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
@@ -753,7 +758,8 @@ def open_group(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members,
align_chunks=align_chunks,
cache_members=cache_members,
)

def __init__(
@@ -767,8 +773,13 @@ def __init__(
write_empty: bool | None = None,
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
align_chunks: bool = False,
cache_members: bool = True,
):
if align_chunks:
# Disable the safe_chunks validation when chunk alignment is applied
safe_chunks = False

self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
self._synchronizer = self.zarr_group.synchronizer
@@ -777,6 +788,7 @@ def __init__(
self._consolidate_on_close = consolidate_on_close
self._append_dim = append_dim
self._write_region = write_region
self._align_chunks = align_chunks
self._safe_chunks = safe_chunks
self._write_empty = write_empty
self._close_store_on_close = close_store_on_close
@@ -1139,7 +1151,13 @@ def _create_new_array(
zarr_array = _put_attrs(zarr_array, attrs)
return zarr_array

def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
def set_variables(
self,
variables: dict[str, Variable],
check_encoding_set,
writer,
unlimited_dims=None,
):
"""
This provides a centralized method to set the variables on the data
store.
@@ -1244,6 +1262,12 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
attrs=encoded_attrs,
)

if self._align_chunks:
v = ChunksUtilities.align_variable_chunks(
v,
encoding["chunks"],
region,
)
writer.add(v.data, zarr_array, region)

def close(self) -> None:
3 changes: 3 additions & 0 deletions xarray/core/dataarray.py
@@ -4214,6 +4214,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
@@ -4237,6 +4238,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
@@ -4258,6 +4260,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
3 changes: 3 additions & 0 deletions xarray/core/dataset.py
@@ -2057,6 +2057,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
@@ -2080,6 +2081,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
@@ -2101,6 +2103,7 @@ def to_zarr(
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
15 changes: 14 additions & 1 deletion xarray/tests/test_backends.py
@@ -39,7 +39,7 @@
open_mfdataset,
save_mfdataset,
)
from xarray.backends.common import robust_getitem
from xarray.backends.common import ChunksUtilities, robust_getitem
from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
from xarray.backends.netcdf3 import _nc3_dtype_coercions
from xarray.backends.netCDF4_ import (
@@ -832,7 +832,7 @@
)

with self.roundtrip(ds) as on_disk:
subset = on_disk.isel(t=[0], p=0).z[:, ::10, ::10][:, ::-1, :]

[15 CI check-failure annotations on line 835 in xarray/tests/test_backends.py: test_outer_indexing_reversed fails for TestDask, TestNetCDF4ViaDaskData, and TestH5NetCDFViaDaskData on the macos-latest and ubuntu-latest py3.10/py3.13 jobs with "ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4"]
assert subset.sizes == subset.load().sizes

def test_isel_dataarray(self) -> None:
@@ -6759,3 +6759,16 @@
storage_options={"skip_instance_cache": False},
) as ds:
assert_identical(xr.concat([ds1, ds2], dim="time"), ds)


def test_align_variable_chunks():
arr = xr.DataArray(
list(range(11)), dims=["a"], coords={"a": list(range(11))}, name="foo"
)
region_arr = arr.isel(a=slice(0, 5)).chunk(a=(3, 1, 1))

[2 CI check-failure annotations on line 6768 in xarray/tests/test_backends.py: test_align_variable_chunks fails on the ubuntu-latest py3.10 bare-minimum and py3.12 all-but-dask jobs with "ImportError: chunk manager 'dask' is not available. Please make sure 'dask' is installed and importable."]
result = ChunksUtilities.align_variable_chunks(
region_arr.variable,
enc_chunks=(3,),
regions=(slice(0, 5),),
)
assert result.chunks == ((3, 2),)