Skip to content

Check for aligned chunks when writing to existing variables #8459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 29, 2024
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ Deprecations
Bug fixes
~~~~~~~~~

- Partial writes to existing chunks with ``region`` will now raise an error
(unless ``safe_chunks=False``); previously an error would only be raised on
new variables. (:pull:`8459`, :issue:`8371`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Port `bug fix from pandas <https://github.com/pandas-dev/pandas/pull/55283>`_
to eliminate the adjustment of resample bin edges in the case that the
resampling frequency has units of days and is greater than one day
Expand Down
16 changes: 12 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise NotImplementedError(
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
Expand Down Expand Up @@ -679,6 +679,17 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=check,
name=vn,
safe_chunks=self._safe_chunks,
)

if name in self.zarr_group:
# existing variable
# TODO: if mode="a", consider overriding the existing variable
Expand Down Expand Up @@ -709,9 +720,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
zarr_array = self.zarr_group[name]
else:
# new variable
encoding = extract_zarr_variable_encoding(
v, raise_on_invalid=check, name=vn, safe_chunks=self._safe_chunks
)
encoded_attrs = {}
# the magic for storing the hidden dimension data
encoded_attrs[DIMENSION_KEY] = dims
Expand Down
28 changes: 23 additions & 5 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# should fail if encoding["chunks"] clashes with dask_chunks
badenc = ds.chunk({"x": 4})
badenc.var1.encoding["chunks"] = (6,)
with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
with pytest.raises(ValueError, match=r"named 'var1' would overlap"):
with self.roundtrip(badenc) as actual:
pass

Expand Down Expand Up @@ -2138,9 +2138,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# but itermediate unaligned chunks are bad
badenc = ds.chunk({"x": (3, 5, 3, 1)})
badenc.var1.encoding["chunks"] = (3,)
with pytest.raises(
NotImplementedError, match=r"would overlap multiple dask chunks"
):
with pytest.raises(ValueError, match=r"would overlap multiple dask chunks"):
with self.roundtrip(badenc) as actual:
pass

Expand All @@ -2154,7 +2152,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# TODO: remove this failure once synchronized overlapping writes are
# supported by xarray
ds_chunk4["var1"].encoding.update({"chunks": 5})
with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
with pytest.raises(ValueError, match=r"named 'var1' would overlap"):
with self.roundtrip(ds_chunk4) as actual:
pass
# override option
Expand Down Expand Up @@ -5442,3 +5440,23 @@ def test_zarr_region_transpose(tmp_path):
ds_region.to_zarr(
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
)


@requires_zarr
@requires_dask
def test_zarr_region_chunk_partial(tmp_path):
"""
Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`.
"""
ds = (
xr.DataArray(np.arange(120).reshape(4, 3, -1), dims=list("abc"))
.rename("var1")
.to_dataset()
)

ds.chunk(5).to_zarr(tmp_path / "foo.zarr", compute=False, mode="w")
with pytest.raises(ValueError):
for r in range(ds.sizes["a"]):
ds.chunk(3).isel(a=[r]).to_zarr(
tmp_path / "foo.zarr", region=dict(a=slice(r, r + 1))
)