Commit 2a6212e

josephnowak, pre-commit-ci[bot], and max-sixty authored
Improve safe chunk validation (#9527)
* fix safe chunks validation
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix safe chunks validation
* Update xarray/tests/test_backends.py Co-authored-by: Maximilian Roos <[email protected]>
* The validation of the chunks now is able to detect full or partial chunk and raise a proper error based on the mode selected, it is also possible to use the auto region detection with the mode "a"
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* The test_extract_zarr_variable_encoding does not need to use the region parameter
* Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+
* Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+
* Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete"
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete"
* Add a typehint to the modes to avoid issues with mypy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Maximilian Roos <[email protected]>
1 parent e649e13 commit 2a6212e

5 files changed: +303 -54 lines changed

doc/whats-new.rst

+3 -1
@@ -58,7 +58,9 @@ Bug fixes
   <https://github.com/spencerkclark>`_.
 - Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
   By `Deepak Cherian <https://github.com/dcherian>`_.
-
+- Fix the safe_chunks validation option on the to_zarr method
+  (:issue:`5511`, :pull:`9513`). By `Joseph Nowak
+  <https://github.com/josephnowak>`_.

 Documentation
 ~~~~~~~~~~~~~

xarray/backends/zarr.py

+121 -47
@@ -112,7 +112,9 @@ def __getitem__(self, key):
         # could possibly have a work-around for 0d data here


-def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
+def _determine_zarr_chunks(
+    enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
+):
     """
     Given encoding chunks (possibly None or []) and variable chunks
     (possibly None or []).
@@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):

     if len(enc_chunks_tuple) != ndim:
         # throw away encoding chunks, start over
-        return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
+        return _determine_zarr_chunks(
+            None, var_chunks, ndim, name, safe_chunks, region, mode, shape
+        )

     for x in enc_chunks_tuple:
         if not isinstance(x, int):
@@ -189,20 +193,59 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
     # TODO: incorporate synchronizer to allow writes from multiple dask
     # threads
     if var_chunks and enc_chunks_tuple:
-        for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
-            for dchunk in dchunks[:-1]:
+        # If it is possible to write on partial chunks then it is not necessary to check
+        # the last one contained on the region
+        allow_partial_chunks = mode != "r+"
+
+        base_error = (
+            f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
+            f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
+            f"on the region {region}. "
+            f"Writing this array in parallel with dask could lead to corrupted data."
+            f"Consider either rechunking using `chunk()`, deleting "
+            f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
+        )
+
+        for zchunk, dchunks, interval, size in zip(
+            enc_chunks_tuple, var_chunks, region, shape, strict=True
+        ):
+            if not safe_chunks:
+                continue
+
+            for dchunk in dchunks[1:-1]:
                 if dchunk % zchunk:
-                    base_error = (
-                        f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
-                        f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
-                        f"Writing this array in parallel with dask could lead to corrupted data."
-                    )
-                    if safe_chunks:
-                        raise ValueError(
-                            base_error
-                            + " Consider either rechunking using `chunk()`, deleting "
-                            "or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
-                        )
+                    raise ValueError(base_error)
+
+            region_start = interval.start if interval.start else 0
+
+            if len(dchunks) > 1:
+                # The first border size is the amount of data that needs to be updated on the
+                # first chunk taking into account the region slice.
+                first_border_size = zchunk
+                if allow_partial_chunks:
+                    first_border_size = zchunk - region_start % zchunk
+
+                if (dchunks[0] - first_border_size) % zchunk:
+                    raise ValueError(base_error)
+
+            if not allow_partial_chunks:
+                chunk_start = sum(dchunks[:-1]) + region_start
+                if chunk_start % zchunk:
+                    # The last chunk which can also be the only one is a partial chunk
+                    # if it is not aligned at the beginning
+                    raise ValueError(base_error)
+
+                region_stop = interval.stop if interval.stop else size
+
+                if size - region_stop + 1 < zchunk:
+                    # If the region is covering the last chunk then check
+                    # if the reminder with the default chunk size
+                    # is equal to the size of the last chunk
+                    if dchunks[-1] % zchunk != size % zchunk:
+                        raise ValueError(base_error)
+                elif dchunks[-1] % zchunk:
+                    raise ValueError(base_error)
+
         return enc_chunks_tuple

     raise AssertionError("We should never get here. Function logic must be wrong.")
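
The hunk above is easier to follow with concrete numbers. The following is a minimal sketch that restates the added checks for a single axis and returns a boolean instead of raising. The function name `is_aligned` and its flat argument list are hypothetical and not part of xarray; the arithmetic is copied from the added lines.

```python
def is_aligned(zchunk, dchunks, region_start, region_stop, size, mode):
    """Hypothetical single-axis restatement of the checks added in this hunk.

    zchunk: Zarr chunk length; dchunks: tuple of dask chunk lengths;
    region_start/region_stop: bounds of the written region on this axis;
    size: length of the Zarr array on this axis; mode: "r+" or "a".
    """
    # Partial chunk writes are only tolerated outside mode "r+".
    allow_partial_chunks = mode != "r+"

    # Interior dask chunks must be whole multiples of the Zarr chunk size.
    if any(dchunk % zchunk for dchunk in dchunks[1:-1]):
        return False

    if len(dchunks) > 1:
        # Amount of data that fills the first Zarr chunk touched by the region.
        first_border_size = zchunk
        if allow_partial_chunks:
            first_border_size = zchunk - region_start % zchunk
        if (dchunks[0] - first_border_size) % zchunk:
            return False

    if not allow_partial_chunks:
        # The last dask chunk must start on a Zarr chunk boundary.
        chunk_start = sum(dchunks[:-1]) + region_start
        if chunk_start % zchunk:
            return False
        if size - region_stop + 1 < zchunk:
            # Region reaches the last Zarr chunk, which may be shorter.
            if dchunks[-1] % zchunk != size % zchunk:
                return False
        elif dchunks[-1] % zchunk:
            return False

    return True


# Region starting at offset 3 with Zarr chunks of 5: the first dask chunk of
# length 2 completes the first Zarr chunk, which only mode "a" accepts.
print(is_aligned(5, (2, 5, 5, 5), 3, 20, 20, "a"))
print(is_aligned(5, (2, 5, 5, 5), 3, 20, 20, "r+"))
```

With a Zarr chunk size of 5, a region starting at 3 leaves a first border of 2 elements, so dask chunks `(2, 5, 5, 5)` line up in mode `"a"`, while mode `"r+"` treats the partially written first chunk as an error.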
@@ -243,7 +286,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):


 def extract_zarr_variable_encoding(
-    variable, raise_on_invalid=False, name=None, safe_chunks=True
+    variable,
+    raise_on_invalid=False,
+    name=None,
+    *,
+    safe_chunks=True,
+    region=None,
+    mode=None,
+    shape=None,
 ):
     """
     Extract zarr encoding dictionary from xarray Variable
@@ -252,12 +302,18 @@ def extract_zarr_variable_encoding(
     ----------
     variable : Variable
     raise_on_invalid : bool, optional
-
+    name: str | Hashable, optional
+    safe_chunks: bool, optional
+    region: tuple[slice, ...], optional
+    mode: str, optional
+    shape: tuple[int, ...], optional
     Returns
     -------
     encoding : dict
         Zarr encoding for `variable`
     """
+
+    shape = shape if shape else variable.shape
     encoding = variable.encoding.copy()

     safe_to_drop = {"source", "original_shape"}
@@ -285,7 +341,14 @@ def extract_zarr_variable_encoding(
                 del encoding[k]

     chunks = _determine_zarr_chunks(
-        encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
+        enc_chunks=encoding.get("chunks"),
+        var_chunks=variable.chunks,
+        ndim=variable.ndim,
+        name=name,
+        safe_chunks=safe_chunks,
+        region=region,
+        mode=mode,
+        shape=shape,
     )
     encoding["chunks"] = chunks
     return encoding
@@ -762,16 +825,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
             if v.encoding == {"_FillValue": None} and fill_value is None:
                 v.encoding = {}

-            # We need to do this for both new and existing variables to ensure we're not
-            # writing to a partial chunk, even though we don't use the `encoding` value
-            # when writing to an existing variable. See
-            # https://github.com/pydata/xarray/issues/8371 for details.
-            encoding = extract_zarr_variable_encoding(
-                v,
-                raise_on_invalid=vn in check_encoding_set,
-                name=vn,
-                safe_chunks=self._safe_chunks,
-            )
+            zarr_array = None
+            zarr_shape = None
+            write_region = self._write_region if self._write_region is not None else {}
+            write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

             if name in existing_keys:
                 # existing variable
@@ -801,7 +858,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
                     )
                 else:
                     zarr_array = self.zarr_group[name]
-            else:
+
+                if self._append_dim is not None and self._append_dim in dims:
+                    # resize existing variable
+                    append_axis = dims.index(self._append_dim)
+                    assert write_region[self._append_dim] == slice(None)
+                    write_region[self._append_dim] = slice(
+                        zarr_array.shape[append_axis], None
+                    )
+
+                    new_shape = list(zarr_array.shape)
+                    new_shape[append_axis] += v.shape[append_axis]
+                    zarr_array.resize(new_shape)
+
+                zarr_shape = zarr_array.shape
+
+            region = tuple(write_region[dim] for dim in dims)
+
+            # We need to do this for both new and existing variables to ensure we're not
+            # writing to a partial chunk, even though we don't use the `encoding` value
+            # when writing to an existing variable. See
+            # https://github.com/pydata/xarray/issues/8371 for details.
+            # Note: Ideally there should be two functions, one for validating the chunks and
+            # another one for extracting the encoding.
+            encoding = extract_zarr_variable_encoding(
+                v,
+                raise_on_invalid=vn in check_encoding_set,
+                name=vn,
+                safe_chunks=self._safe_chunks,
+                region=region,
+                mode=self._mode,
+                shape=zarr_shape,
+            )
+
+            if name not in existing_keys:
                 # new variable
                 encoded_attrs = {}
                 # the magic for storing the hidden dimension data
@@ -833,22 +923,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
                 )
                 zarr_array = _put_attrs(zarr_array, encoded_attrs)

-            write_region = self._write_region if self._write_region is not None else {}
-            write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
-
-            if self._append_dim is not None and self._append_dim in dims:
-                # resize existing variable
-                append_axis = dims.index(self._append_dim)
-                assert write_region[self._append_dim] == slice(None)
-                write_region[self._append_dim] = slice(
-                    zarr_array.shape[append_axis], None
-                )
-
-                new_shape = list(zarr_array.shape)
-                new_shape[append_axis] += v.shape[append_axis]
-                zarr_array.resize(new_shape)
-
-            region = tuple(write_region[dim] for dim in dims)
             writer.add(v.data, zarr_array, region)

     def close(self) -> None:
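
The block that this commit moves ahead of the encoding extraction computes the write region and the resized shape for an append. A minimal sketch of that bookkeeping, using plain tuples in place of a Zarr array, might look as follows. The helper name `append_region` is hypothetical, and the sketch assumes no explicit `region` was passed, so every dimension starts as `slice(None)`.

```python
def append_region(existing_shape, new_data_shape, dims, append_dim):
    """Hypothetical helper mirroring the append bookkeeping in set_variables.

    Returns the region the new data is written to and the shape the
    existing array is resized to.
    """
    # Every dimension defaults to the full slice (no explicit region given).
    write_region = {dim: slice(None) for dim in dims}

    # New data starts where the existing array currently ends.
    append_axis = dims.index(append_dim)
    write_region[append_dim] = slice(existing_shape[append_axis], None)

    # The existing array grows by the length of the appended data.
    new_shape = list(existing_shape)
    new_shape[append_axis] += new_data_shape[append_axis]

    region = tuple(write_region[dim] for dim in dims)
    return region, tuple(new_shape)


region, resized = append_region((10, 4), (3, 4), ["time", "x"], "time")
print(region, resized)
```

Computing the region before the encoding is extracted is what lets the chunk validation see the start offset of the appended data and the resized shape.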
@@ -897,9 +971,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
         if not isinstance(region, dict):
             raise TypeError(f"``region`` must be a dict, got {type(region)}")
         if any(v == "auto" for v in region.values()):
-            if self._mode != "r+":
+            if self._mode not in ["r+", "a"]:
                 raise ValueError(
-                    f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
+                    f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
                 )
             region = self._auto_detect_regions(ds, region)

xarray/core/dataarray.py

+8
@@ -4316,6 +4316,14 @@ def to_zarr(
             if Zarr arrays are written in parallel. This option may be useful in combination
             with ``compute=False`` to initialize a Zarr store from an existing
             DataArray with arbitrary chunk structure.
+            In addition to the many-to-one relationship validation, it also detects partial
+            chunks writes when using the region parameter,
+            these partial chunks are considered unsafe in the mode "r+" but safe in
+            the mode "a".
+            Note: Even with these validations it can still be unsafe to write
+            two or more chunked arrays in the same location in parallel if they are
+            not writing in independent regions, for those cases it is better to use
+            a synchronizer.
         storage_options : dict, optional
             Any additional parameters for the storage backend (ignored for local
             paths).
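
To make "partial chunk writes" in the docstring concrete, the sketch below lists which Zarr chunks along one axis are only partly covered by a region. It is an illustration under stated assumptions, not xarray's implementation: the name `partial_chunks` is hypothetical, and the region is assumed to be a non-empty half-open interval inside an axis of length `size`.

```python
def partial_chunks(region_start, region_stop, zchunk, size):
    """Indices of Zarr chunks only partly covered by [region_start, region_stop).

    Hypothetical illustration; assumes 0 <= region_start < region_stop <= size.
    """
    partial = []
    first = region_start // zchunk
    last = (region_stop - 1) // zchunk
    for idx in range(first, last + 1):
        chunk_start = idx * zchunk
        # The final chunk of the array may be shorter than zchunk.
        chunk_stop = min(chunk_start + zchunk, size)
        if region_start > chunk_start or region_stop < chunk_stop:
            partial.append(idx)
    return partial


# A region starting at 3 with Zarr chunks of 5 only partly covers chunk 0.
print(partial_chunks(3, 10, 5, 20))
# A region that ends with the array fully covers the short final chunk.
print(partial_chunks(0, 18, 5, 18))
```

Per the docstring, a region for which this list is non-empty is the kind of write that is considered unsafe in mode "r+" and accepted in mode "a".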

xarray/core/dataset.py

+8
@@ -2509,6 +2509,14 @@ def to_zarr(
             if Zarr arrays are written in parallel. This option may be useful in combination
             with ``compute=False`` to initialize a Zarr from an existing
             Dataset with arbitrary chunk structure.
+            In addition to the many-to-one relationship validation, it also detects partial
+            chunks writes when using the region parameter,
+            these partial chunks are considered unsafe in the mode "r+" but safe in
+            the mode "a".
+            Note: Even with these validations it can still be unsafe to write
+            two or more chunked arrays in the same location in parallel if they are
+            not writing in independent regions, for those cases it is better to use
+            a synchronizer.
         storage_options : dict, optional
             Any additional parameters for the storage backend (ignored for local
             paths).
