Skip to content

Automatic region detection and transpose for to_zarr() #8434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 14, 2023
Merged
13 changes: 8 additions & 5 deletions doc/user-guide/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -876,17 +876,20 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata
ds.to_zarr(path, compute=False)

Now, a Zarr store with the correct variable shapes and attributes exists that
can be filled out by subsequent calls to ``to_zarr``. The ``region`` provides a
mapping from dimension names to Python ``slice`` objects indicating where the
data should be written (in index space, not coordinate space), e.g.,
can be filled out by subsequent calls to ``to_zarr``. ``region`` can be
specified as ``"auto"``, which opens the existing store and determines the
correct alignment of the new data with the existing coordinates, or as an
explicit mapping from dimension names to Python ``slice`` objects indicating
where the data should be written (in index space, not coordinate space), e.g.,

.. ipython:: python

# For convenience, we'll slice a single dataset, but in the real use-case
# we would create them separately possibly even from separate processes.
ds = xr.Dataset({"foo": ("x", np.arange(30))})
ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)})
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)})
# Any of the following region specifications are valid
ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)})

Concurrent writes with ``region`` are safe as long as they modify distinct
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ New Features
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
By `Ben Mares <https://github.com/maresb>`_.
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the
region to write in the original store. Also implement automatic transpose when dimension
order does not match the original store. (:issue:`7702`, :issue:`8421`, :pull:`8434`).
By `Sam Levang <https://github.com/slevang>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
65 changes: 60 additions & 5 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_normalize_path,
)
from xarray.backends.locks import _get_scheduler
from xarray.backends.zarr import open_zarr
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -1446,10 +1447,54 @@ def save_mfdataset(
)


def _validate_region(ds, region):
def _auto_detect_region(ds_new, ds_orig, dim):
    """Infer the integer slice of ``dim`` in ``ds_orig`` covered by ``ds_new``.

    Parameters
    ----------
    ds_new : Dataset
        The dataset being written to a region of an existing store.
    ds_orig : Dataset
        The dataset opened from the existing store.
    dim : Hashable
        Dimension name whose region should be auto-detected.

    Returns
    -------
    slice
        Index-space slice of ``ds_orig[dim]`` that aligns with ``ds_new[dim]``.

    Raises
    ------
    KeyError
        If some coordinate values of ``ds_new[dim]`` are absent from the store.
    ValueError
        If the matched indices are not contiguous, so no single slice exists.
    """
    # Build a lookup array mapping each coordinate value in the original
    # store to its integer position along ``dim``.
    coord = ds_orig[dim]
    da_map = DataArray(np.arange(coord.size), coords={dim: coord})

    try:
        da_idxs = da_map.sel({dim: ds_new[dim]})
    except KeyError as e:
        if "not all values found" in str(e):
            # Chain the original lookup failure for easier debugging.
            raise KeyError(
                f"Not all values of coordinate '{dim}' in the new array were"
                " found in the original store. Writing to a zarr region slice"
                " requires that no dimensions or metadata are changed by the write."
            ) from e
        raise

    # Consecutive indices must step by exactly 1 for a slice to be valid.
    if (da_idxs.diff(dim) != 1).any():
        raise ValueError(
            f"The auto-detected region of coordinate '{dim}' for writing new data"
            " to the original store had non-contiguous indices. Writing to a zarr"
            " region slice requires that the new data constitute a contiguous subset"
            " of the original store."
        )

    dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1)

    return dim_slice


def _auto_detect_regions(ds, region, open_kwargs):
    """Resolve every ``"auto"`` entry in ``region`` to a concrete slice.

    Opens the existing zarr store described by ``open_kwargs`` once, then
    replaces each ``"auto"`` value with the slice inferred from coordinate
    alignment. Returns a new dict rather than mutating the caller's mapping
    in place (``region`` may be a user-supplied dict passed to ``to_zarr``).
    """
    ds_original = open_zarr(**open_kwargs)
    return {
        dim: _auto_detect_region(ds, ds_original, dim) if val == "auto" else val
        for dim, val in region.items()
    }


def _validate_and_autodetect_region(ds, region, open_kwargs) -> dict[str, slice]:
if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")

if any(v == "auto" for v in region.values()):
region = _auto_detect_regions(ds, region, open_kwargs)

for k, v in region.items():
if k not in ds.dims:
raise ValueError(
Expand Down Expand Up @@ -1481,6 +1526,8 @@ def _validate_region(ds, region):
f".drop_vars({non_matching_vars!r})"
)

return region


def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
Expand Down Expand Up @@ -1532,7 +1579,7 @@ def to_zarr(
compute: Literal[True] = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -1556,7 +1603,7 @@ def to_zarr(
compute: Literal[False],
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -1578,7 +1625,7 @@ def to_zarr(
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand Down Expand Up @@ -1643,7 +1690,15 @@ def to_zarr(
_validate_dataset_names(dataset)

if region is not None:
_validate_region(dataset, region)
open_kwargs = dict(
store=store,
synchronizer=synchronizer,
group=group,
consolidated=consolidated,
storage_options=storage_options,
zarr_version=zarr_version,
)
region = _validate_and_autodetect_region(dataset, region, open_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhere around here might be the right place to drop the dimension coordinates for region='auto'.

if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
Expand Down
23 changes: 15 additions & 8 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,19 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
return var


def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim):
def _validate_and_transpose_existing_dims(
var_name, new_var, existing_var, region, append_dim
):
if new_var.dims != existing_var.dims:
raise ValueError(
f"variable {var_name!r} already exists with different "
f"dimension names {existing_var.dims} != "
f"{new_var.dims}, but changing variable "
f"dimensions is not supported by to_zarr()."
)
if set(existing_var.dims) == set(new_var.dims):
new_var = new_var.transpose(*existing_var.dims)
else:
raise ValueError(
f"variable {var_name!r} already exists with different "
f"dimension names {existing_var.dims} != "
f"{new_var.dims}, but changing variable "
f"dimensions is not supported by to_zarr()."
)

existing_sizes = {}
for dim, size in existing_var.sizes.items():
Expand All @@ -347,6 +352,8 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
f"explicitly appending, but append_dim={append_dim!r}."
)

return new_var


def _put_attrs(zarr_obj, attrs):
"""Raise a more informative error message for invalid attrs."""
Expand Down Expand Up @@ -616,7 +623,7 @@ def store(
for var_name in existing_variable_names:
new_var = variables_encoded[var_name]
existing_var = existing_vars[var_name]
_validate_existing_dims(
new_var = _validate_and_transpose_existing_dims(
var_name,
new_var,
existing_var,
Expand Down
14 changes: 10 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,7 +2305,7 @@ def to_zarr(
compute: Literal[True] = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -2328,7 +2328,7 @@ def to_zarr(
compute: Literal[False],
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -2349,7 +2349,7 @@ def to_zarr(
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand Down Expand Up @@ -2411,14 +2411,20 @@ def to_zarr(
append_dim : hashable, optional
If set, the dimension along which the data will be appended. All
other dimensions on overridden variables must remain the same size.
region : dict, optional
region : dict or "auto", optional
Optional mapping from dimension names to integer slices along
dataset dimensions to indicate the region of existing zarr array(s)
in which to write this dataset's data. For example,
``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate
that values should be written to the region ``0:1000`` along ``x``
and ``10000:11000`` along ``y``.

Can also specify ``"auto"``, in which case the existing store will be
opened and the region inferred by matching the new data's coordinates.
``"auto"`` can be used as a single string, which will automatically infer
the region for all dimensions, or as dictionary values for specific
dimensions mixed together with explicit slices for other dimensions.

Two restrictions apply to the use of ``region``:

- If ``region`` is set, _all_ variables in a dataset must have at
Expand Down
125 changes: 125 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5431,3 +5431,128 @@ def test_raise_writing_to_nczarr(self, mode) -> None:
def test_pickle_open_mfdataset_dataset():
ds = open_example_mfdataset(["bears.nc"])
assert_identical(ds, pickle.loads(pickle.dumps(ds)))


@requires_zarr
class TestZarrRegionAuto:
def test_zarr_region_auto_all(self, tmp_path):
    """region="auto" infers the write slice for every dimension."""
    store = tmp_path / "test.zarr"
    coords = {"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)}
    ds = xr.Dataset(
        {"test": xr.DataArray(np.ones((5, 10)), dims=("x", "y"), coords=coords)}
    )
    ds.to_zarr(store)

    ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
    ds_region.to_zarr(store, region="auto")

    expected = ds.copy()
    expected["test"][2:4, 6:8] += 1
    assert_identical(xr.open_zarr(store), expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be the place to add a test to verify that dimension coordinates are not being overwritten. To accomplish this, you could use a similar pattern to what @olimcc did in #8297: patch the store object so that it counts how many times a certain method has been called

with (
self.create_zarr_target() as store,
patch.object(
Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
) as mock,
):
ds.to_zarr(store, mode="w")
# We expect this to request array metadata information, so call_count should be >= 1,
# At time of writing, 2 calls are made
xrds = xr.open_zarr(store)
call_count = mock.call_count
assert call_count > 0
# compute() requests array data, which should not trigger additional metadata requests
# we assert that the number of calls has not increased after fetching the array
xrds.test.compute(scheduler="sync")
assert mock.call_count == call_count

Here you would want to patch __setitem__ rather than __getitem__.


def test_zarr_region_auto_mixed(self, tmp_path):
    """"auto" for one dimension can be mixed with an explicit slice for another."""
    store = tmp_path / "test.zarr"
    coords = {"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)}
    ds = xr.Dataset(
        {"test": xr.DataArray(np.ones((5, 10)), dims=("x", "y"), coords=coords)}
    )
    ds.to_zarr(store)

    ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
    ds_region.to_zarr(store, region={"x": "auto", "y": slice(6, 8)})

    expected = ds.copy()
    expected["test"][2:4, 6:8] += 1
    assert_identical(xr.open_zarr(store), expected)

def test_zarr_region_auto_noncontiguous(self, tmp_path):
    """A non-contiguous selection cannot be expressed as a region slice."""
    store = tmp_path / "test.zarr"
    coords = {"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)}
    ds = xr.Dataset(
        {"test": xr.DataArray(np.ones((5, 10)), dims=("x", "y"), coords=coords)}
    )
    ds.to_zarr(store)

    # x indices [0, 2, 3] skip index 1, so auto-detection must refuse.
    ds_region = 1 + ds.isel(x=[0, 2, 3], y=[5, 6])
    with pytest.raises(ValueError):
        ds_region.to_zarr(store, region={"x": "auto", "y": "auto"})

def test_zarr_region_auto_new_coord_vals(self, tmp_path):
    """Coordinate values absent from the original store raise KeyError."""
    store = tmp_path / "test.zarr"
    data = np.ones((5, 10))
    y = np.arange(0, 20, 2)

    ds = xr.Dataset(
        {
            "test": xr.DataArray(
                data,
                dims=("x", "y"),
                coords={"x": np.arange(0, 50, 10), "y": y},
            )
        }
    )
    ds.to_zarr(store)

    # Same shape, but x coordinates shifted by 5 — none match the store.
    ds = xr.Dataset(
        {
            "test": xr.DataArray(
                data,
                dims=("x", "y"),
                coords={"x": np.arange(5, 55, 10), "y": y},
            )
        }
    )

    ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
    with pytest.raises(KeyError):
        ds_region.to_zarr(store, region={"x": "auto", "y": "auto"})


def test_zarr_region_transpose(tmp_path):
    """A region write succeeds even when the new data's dims are transposed."""
    store = tmp_path / "test.zarr"
    coords = {"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)}
    ds = xr.Dataset(
        {"test": xr.DataArray(np.ones((5, 10)), dims=("x", "y"), coords=coords)}
    )
    ds.to_zarr(store)

    # Dimension order is (y, x) here; to_zarr should transpose back to (x, y).
    ds_region = 1 + ds.isel(x=[0], y=[0]).transpose()
    ds_region.to_zarr(store, region={"x": slice(0, 1), "y": slice(0, 1)})