Skip to content

Commit 617e49a

Browse files
authored
Revert "Improve safe chunk validation (#9527)"
This reverts commit 2a6212e.
1 parent ece582d commit 617e49a

File tree

5 files changed, with 54 additions and 303 deletions.

doc/whats-new.rst

+1-3
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ Bug fixes
6161
<https://github.com/spencerkclark>`_.
6262
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
6363
By `Deepak Cherian <https://github.com/dcherian>`_.
64-
- Fix the safe_chunks validation option on the to_zarr method
65-
(:issue:`5511`, :pull:`9513`). By `Joseph Nowak
66-
<https://github.com/josephnowak>`_.
64+
6765

6866
Documentation
6967
~~~~~~~~~~~~~

xarray/backends/zarr.py

+47-121
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ def __getitem__(self, key):
112112
# could possibly have a work-around for 0d data here
113113

114114

115-
def _determine_zarr_chunks(
116-
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
117-
):
115+
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
118116
"""
119117
Given encoding chunks (possibly None or []) and variable chunks
120118
(possibly None or []).
@@ -165,9 +163,7 @@ def _determine_zarr_chunks(
165163

166164
if len(enc_chunks_tuple) != ndim:
167165
# throw away encoding chunks, start over
168-
return _determine_zarr_chunks(
169-
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
170-
)
166+
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
171167

172168
for x in enc_chunks_tuple:
173169
if not isinstance(x, int):
@@ -193,59 +189,20 @@ def _determine_zarr_chunks(
193189
# TODO: incorporate synchronizer to allow writes from multiple dask
194190
# threads
195191
if var_chunks and enc_chunks_tuple:
196-
# If it is possible to write on partial chunks then it is not necessary to check
197-
# the last one contained on the region
198-
allow_partial_chunks = mode != "r+"
199-
200-
base_error = (
201-
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
202-
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
203-
f"on the region {region}. "
204-
f"Writing this array in parallel with dask could lead to corrupted data."
205-
f"Consider either rechunking using `chunk()`, deleting "
206-
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
207-
)
208-
209-
for zchunk, dchunks, interval, size in zip(
210-
enc_chunks_tuple, var_chunks, region, shape, strict=True
211-
):
212-
if not safe_chunks:
213-
continue
214-
215-
for dchunk in dchunks[1:-1]:
192+
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
193+
for dchunk in dchunks[:-1]:
216194
if dchunk % zchunk:
217-
raise ValueError(base_error)
218-
219-
region_start = interval.start if interval.start else 0
220-
221-
if len(dchunks) > 1:
222-
# The first border size is the amount of data that needs to be updated on the
223-
# first chunk taking into account the region slice.
224-
first_border_size = zchunk
225-
if allow_partial_chunks:
226-
first_border_size = zchunk - region_start % zchunk
227-
228-
if (dchunks[0] - first_border_size) % zchunk:
229-
raise ValueError(base_error)
230-
231-
if not allow_partial_chunks:
232-
chunk_start = sum(dchunks[:-1]) + region_start
233-
if chunk_start % zchunk:
234-
# The last chunk which can also be the only one is a partial chunk
235-
# if it is not aligned at the beginning
236-
raise ValueError(base_error)
237-
238-
region_stop = interval.stop if interval.stop else size
239-
240-
if size - region_stop + 1 < zchunk:
241-
# If the region is covering the last chunk then check
242-
# if the reminder with the default chunk size
243-
# is equal to the size of the last chunk
244-
if dchunks[-1] % zchunk != size % zchunk:
245-
raise ValueError(base_error)
246-
elif dchunks[-1] % zchunk:
247-
raise ValueError(base_error)
248-
195+
base_error = (
196+
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
197+
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
198+
f"Writing this array in parallel with dask could lead to corrupted data."
199+
)
200+
if safe_chunks:
201+
raise ValueError(
202+
base_error
203+
+ " Consider either rechunking using `chunk()`, deleting "
204+
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
205+
)
249206
return enc_chunks_tuple
250207

251208
raise AssertionError("We should never get here. Function logic must be wrong.")
@@ -286,14 +243,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
286243

287244

288245
def extract_zarr_variable_encoding(
289-
variable,
290-
raise_on_invalid=False,
291-
name=None,
292-
*,
293-
safe_chunks=True,
294-
region=None,
295-
mode=None,
296-
shape=None,
246+
variable, raise_on_invalid=False, name=None, safe_chunks=True
297247
):
298248
"""
299249
Extract zarr encoding dictionary from xarray Variable
@@ -302,18 +252,12 @@ def extract_zarr_variable_encoding(
302252
----------
303253
variable : Variable
304254
raise_on_invalid : bool, optional
305-
name: str | Hashable, optional
306-
safe_chunks: bool, optional
307-
region: tuple[slice, ...], optional
308-
mode: str, optional
309-
shape: tuple[int, ...], optional
255+
310256
Returns
311257
-------
312258
encoding : dict
313259
Zarr encoding for `variable`
314260
"""
315-
316-
shape = shape if shape else variable.shape
317261
encoding = variable.encoding.copy()
318262

319263
safe_to_drop = {"source", "original_shape"}
@@ -341,14 +285,7 @@ def extract_zarr_variable_encoding(
341285
del encoding[k]
342286

343287
chunks = _determine_zarr_chunks(
344-
enc_chunks=encoding.get("chunks"),
345-
var_chunks=variable.chunks,
346-
ndim=variable.ndim,
347-
name=name,
348-
safe_chunks=safe_chunks,
349-
region=region,
350-
mode=mode,
351-
shape=shape,
288+
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
352289
)
353290
encoding["chunks"] = chunks
354291
return encoding
@@ -825,10 +762,16 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
825762
if v.encoding == {"_FillValue": None} and fill_value is None:
826763
v.encoding = {}
827764

828-
zarr_array = None
829-
zarr_shape = None
830-
write_region = self._write_region if self._write_region is not None else {}
831-
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
765+
# We need to do this for both new and existing variables to ensure we're not
766+
# writing to a partial chunk, even though we don't use the `encoding` value
767+
# when writing to an existing variable. See
768+
# https://github.com/pydata/xarray/issues/8371 for details.
769+
encoding = extract_zarr_variable_encoding(
770+
v,
771+
raise_on_invalid=vn in check_encoding_set,
772+
name=vn,
773+
safe_chunks=self._safe_chunks,
774+
)
832775

833776
if name in existing_keys:
834777
# existing variable
@@ -858,40 +801,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
858801
)
859802
else:
860803
zarr_array = self.zarr_group[name]
861-
862-
if self._append_dim is not None and self._append_dim in dims:
863-
# resize existing variable
864-
append_axis = dims.index(self._append_dim)
865-
assert write_region[self._append_dim] == slice(None)
866-
write_region[self._append_dim] = slice(
867-
zarr_array.shape[append_axis], None
868-
)
869-
870-
new_shape = list(zarr_array.shape)
871-
new_shape[append_axis] += v.shape[append_axis]
872-
zarr_array.resize(new_shape)
873-
874-
zarr_shape = zarr_array.shape
875-
876-
region = tuple(write_region[dim] for dim in dims)
877-
878-
# We need to do this for both new and existing variables to ensure we're not
879-
# writing to a partial chunk, even though we don't use the `encoding` value
880-
# when writing to an existing variable. See
881-
# https://github.com/pydata/xarray/issues/8371 for details.
882-
# Note: Ideally there should be two functions, one for validating the chunks and
883-
# another one for extracting the encoding.
884-
encoding = extract_zarr_variable_encoding(
885-
v,
886-
raise_on_invalid=vn in check_encoding_set,
887-
name=vn,
888-
safe_chunks=self._safe_chunks,
889-
region=region,
890-
mode=self._mode,
891-
shape=zarr_shape,
892-
)
893-
894-
if name not in existing_keys:
804+
else:
895805
# new variable
896806
encoded_attrs = {}
897807
# the magic for storing the hidden dimension data
@@ -923,6 +833,22 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
923833
)
924834
zarr_array = _put_attrs(zarr_array, encoded_attrs)
925835

836+
write_region = self._write_region if self._write_region is not None else {}
837+
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
838+
839+
if self._append_dim is not None and self._append_dim in dims:
840+
# resize existing variable
841+
append_axis = dims.index(self._append_dim)
842+
assert write_region[self._append_dim] == slice(None)
843+
write_region[self._append_dim] = slice(
844+
zarr_array.shape[append_axis], None
845+
)
846+
847+
new_shape = list(zarr_array.shape)
848+
new_shape[append_axis] += v.shape[append_axis]
849+
zarr_array.resize(new_shape)
850+
851+
region = tuple(write_region[dim] for dim in dims)
926852
writer.add(v.data, zarr_array, region)
927853

928854
def close(self) -> None:
@@ -971,9 +897,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
971897
if not isinstance(region, dict):
972898
raise TypeError(f"``region`` must be a dict, got {type(region)}")
973899
if any(v == "auto" for v in region.values()):
974-
if self._mode not in ["r+", "a"]:
900+
if self._mode != "r+":
975901
raise ValueError(
976-
f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
902+
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
977903
)
978904
region = self._auto_detect_regions(ds, region)
979905

xarray/core/dataarray.py

-8
Original file line numberDiff line numberDiff line change
@@ -4316,14 +4316,6 @@ def to_zarr(
43164316
if Zarr arrays are written in parallel. This option may be useful in combination
43174317
with ``compute=False`` to initialize a Zarr store from an existing
43184318
DataArray with arbitrary chunk structure.
4319-
In addition to the many-to-one relationship validation, it also detects partial
4320-
chunks writes when using the region parameter,
4321-
these partial chunks are considered unsafe in the mode "r+" but safe in
4322-
the mode "a".
4323-
Note: Even with these validations it can still be unsafe to write
4324-
two or more chunked arrays in the same location in parallel if they are
4325-
not writing in independent regions, for those cases it is better to use
4326-
a synchronizer.
43274319
storage_options : dict, optional
43284320
Any additional parameters for the storage backend (ignored for local
43294321
paths).

xarray/core/dataset.py

-8
Original file line numberDiff line numberDiff line change
@@ -2509,14 +2509,6 @@ def to_zarr(
25092509
if Zarr arrays are written in parallel. This option may be useful in combination
25102510
with ``compute=False`` to initialize a Zarr from an existing
25112511
Dataset with arbitrary chunk structure.
2512-
In addition to the many-to-one relationship validation, it also detects partial
2513-
chunks writes when using the region parameter,
2514-
these partial chunks are considered unsafe in the mode "r+" but safe in
2515-
the mode "a".
2516-
Note: Even with these validations it can still be unsafe to write
2517-
two or more chunked arrays in the same location in parallel if they are
2518-
not writing in independent regions, for those cases it is better to use
2519-
a synchronizer.
25202512
storage_options : dict, optional
25212513
Any additional parameters for the storage backend (ignored for local
25222514
paths).

0 commit comments

Comments
 (0)