@@ -112,9 +112,7 @@ def __getitem__(self, key):
112
112
# could possibly have a work-around for 0d data here
113
113
114
114
115
- def _determine_zarr_chunks (
116
- enc_chunks , var_chunks , ndim , name , safe_chunks , region , mode , shape
117
- ):
115
+ def _determine_zarr_chunks (enc_chunks , var_chunks , ndim , name , safe_chunks ):
118
116
"""
119
117
Given encoding chunks (possibly None or []) and variable chunks
120
118
(possibly None or []).
@@ -165,9 +163,7 @@ def _determine_zarr_chunks(
165
163
166
164
if len (enc_chunks_tuple ) != ndim :
167
165
# throw away encoding chunks, start over
168
- return _determine_zarr_chunks (
169
- None , var_chunks , ndim , name , safe_chunks , region , mode , shape
170
- )
166
+ return _determine_zarr_chunks (None , var_chunks , ndim , name , safe_chunks )
171
167
172
168
for x in enc_chunks_tuple :
173
169
if not isinstance (x , int ):
@@ -193,59 +189,20 @@ def _determine_zarr_chunks(
193
189
# TODO: incorporate synchronizer to allow writes from multiple dask
194
190
# threads
195
191
if var_chunks and enc_chunks_tuple :
196
- # If it is possible to write on partial chunks then it is not necessary to check
197
- # the last one contained on the region
198
- allow_partial_chunks = mode != "r+"
199
-
200
- base_error = (
201
- f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
202
- f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} "
203
- f"on the region { region } . "
204
- f"Writing this array in parallel with dask could lead to corrupted data."
205
- f"Consider either rechunking using `chunk()`, deleting "
206
- f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
207
- )
208
-
209
- for zchunk , dchunks , interval , size in zip (
210
- enc_chunks_tuple , var_chunks , region , shape , strict = True
211
- ):
212
- if not safe_chunks :
213
- continue
214
-
215
- for dchunk in dchunks [1 :- 1 ]:
192
+ for zchunk , dchunks in zip (enc_chunks_tuple , var_chunks , strict = True ):
193
+ for dchunk in dchunks [:- 1 ]:
216
194
if dchunk % zchunk :
217
- raise ValueError (base_error )
218
-
219
- region_start = interval .start if interval .start else 0
220
-
221
- if len (dchunks ) > 1 :
222
- # The first border size is the amount of data that needs to be updated on the
223
- # first chunk taking into account the region slice.
224
- first_border_size = zchunk
225
- if allow_partial_chunks :
226
- first_border_size = zchunk - region_start % zchunk
227
-
228
- if (dchunks [0 ] - first_border_size ) % zchunk :
229
- raise ValueError (base_error )
230
-
231
- if not allow_partial_chunks :
232
- chunk_start = sum (dchunks [:- 1 ]) + region_start
233
- if chunk_start % zchunk :
234
- # The last chunk which can also be the only one is a partial chunk
235
- # if it is not aligned at the beginning
236
- raise ValueError (base_error )
237
-
238
- region_stop = interval .stop if interval .stop else size
239
-
240
- if size - region_stop + 1 < zchunk :
241
- # If the region is covering the last chunk then check
242
- # if the reminder with the default chunk size
243
- # is equal to the size of the last chunk
244
- if dchunks [- 1 ] % zchunk != size % zchunk :
245
- raise ValueError (base_error )
246
- elif dchunks [- 1 ] % zchunk :
247
- raise ValueError (base_error )
248
-
195
+ base_error = (
196
+ f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
197
+ f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} . "
198
+ f"Writing this array in parallel with dask could lead to corrupted data."
199
+ )
200
+ if safe_chunks :
201
+ raise ValueError (
202
+ base_error
203
+ + " Consider either rechunking using `chunk()`, deleting "
204
+ "or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
205
+ )
249
206
return enc_chunks_tuple
250
207
251
208
raise AssertionError ("We should never get here. Function logic must be wrong." )
@@ -286,14 +243,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
286
243
287
244
288
245
def extract_zarr_variable_encoding (
289
- variable ,
290
- raise_on_invalid = False ,
291
- name = None ,
292
- * ,
293
- safe_chunks = True ,
294
- region = None ,
295
- mode = None ,
296
- shape = None ,
246
+ variable , raise_on_invalid = False , name = None , safe_chunks = True
297
247
):
298
248
"""
299
249
Extract zarr encoding dictionary from xarray Variable
@@ -302,18 +252,12 @@ def extract_zarr_variable_encoding(
302
252
----------
303
253
variable : Variable
304
254
raise_on_invalid : bool, optional
305
- name: str | Hashable, optional
306
- safe_chunks: bool, optional
307
- region: tuple[slice, ...], optional
308
- mode: str, optional
309
- shape: tuple[int, ...], optional
255
+
310
256
Returns
311
257
-------
312
258
encoding : dict
313
259
Zarr encoding for `variable`
314
260
"""
315
-
316
- shape = shape if shape else variable .shape
317
261
encoding = variable .encoding .copy ()
318
262
319
263
safe_to_drop = {"source" , "original_shape" }
@@ -341,14 +285,7 @@ def extract_zarr_variable_encoding(
341
285
del encoding [k ]
342
286
343
287
chunks = _determine_zarr_chunks (
344
- enc_chunks = encoding .get ("chunks" ),
345
- var_chunks = variable .chunks ,
346
- ndim = variable .ndim ,
347
- name = name ,
348
- safe_chunks = safe_chunks ,
349
- region = region ,
350
- mode = mode ,
351
- shape = shape ,
288
+ encoding .get ("chunks" ), variable .chunks , variable .ndim , name , safe_chunks
352
289
)
353
290
encoding ["chunks" ] = chunks
354
291
return encoding
@@ -825,10 +762,16 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
825
762
if v .encoding == {"_FillValue" : None } and fill_value is None :
826
763
v .encoding = {}
827
764
828
- zarr_array = None
829
- zarr_shape = None
830
- write_region = self ._write_region if self ._write_region is not None else {}
831
- write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
765
+ # We need to do this for both new and existing variables to ensure we're not
766
+ # writing to a partial chunk, even though we don't use the `encoding` value
767
+ # when writing to an existing variable. See
768
+ # https://github.com/pydata/xarray/issues/8371 for details.
769
+ encoding = extract_zarr_variable_encoding (
770
+ v ,
771
+ raise_on_invalid = vn in check_encoding_set ,
772
+ name = vn ,
773
+ safe_chunks = self ._safe_chunks ,
774
+ )
832
775
833
776
if name in existing_keys :
834
777
# existing variable
@@ -858,40 +801,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
858
801
)
859
802
else :
860
803
zarr_array = self .zarr_group [name ]
861
-
862
- if self ._append_dim is not None and self ._append_dim in dims :
863
- # resize existing variable
864
- append_axis = dims .index (self ._append_dim )
865
- assert write_region [self ._append_dim ] == slice (None )
866
- write_region [self ._append_dim ] = slice (
867
- zarr_array .shape [append_axis ], None
868
- )
869
-
870
- new_shape = list (zarr_array .shape )
871
- new_shape [append_axis ] += v .shape [append_axis ]
872
- zarr_array .resize (new_shape )
873
-
874
- zarr_shape = zarr_array .shape
875
-
876
- region = tuple (write_region [dim ] for dim in dims )
877
-
878
- # We need to do this for both new and existing variables to ensure we're not
879
- # writing to a partial chunk, even though we don't use the `encoding` value
880
- # when writing to an existing variable. See
881
- # https://github.com/pydata/xarray/issues/8371 for details.
882
- # Note: Ideally there should be two functions, one for validating the chunks and
883
- # another one for extracting the encoding.
884
- encoding = extract_zarr_variable_encoding (
885
- v ,
886
- raise_on_invalid = vn in check_encoding_set ,
887
- name = vn ,
888
- safe_chunks = self ._safe_chunks ,
889
- region = region ,
890
- mode = self ._mode ,
891
- shape = zarr_shape ,
892
- )
893
-
894
- if name not in existing_keys :
804
+ else :
895
805
# new variable
896
806
encoded_attrs = {}
897
807
# the magic for storing the hidden dimension data
@@ -923,6 +833,22 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
923
833
)
924
834
zarr_array = _put_attrs (zarr_array , encoded_attrs )
925
835
836
+ write_region = self ._write_region if self ._write_region is not None else {}
837
+ write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
838
+
839
+ if self ._append_dim is not None and self ._append_dim in dims :
840
+ # resize existing variable
841
+ append_axis = dims .index (self ._append_dim )
842
+ assert write_region [self ._append_dim ] == slice (None )
843
+ write_region [self ._append_dim ] = slice (
844
+ zarr_array .shape [append_axis ], None
845
+ )
846
+
847
+ new_shape = list (zarr_array .shape )
848
+ new_shape [append_axis ] += v .shape [append_axis ]
849
+ zarr_array .resize (new_shape )
850
+
851
+ region = tuple (write_region [dim ] for dim in dims )
926
852
writer .add (v .data , zarr_array , region )
927
853
928
854
def close (self ) -> None :
@@ -971,9 +897,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
971
897
if not isinstance (region , dict ):
972
898
raise TypeError (f"``region`` must be a dict, got { type (region )} " )
973
899
if any (v == "auto" for v in region .values ()):
974
- if self ._mode not in [ "r+" , "a" ] :
900
+ if self ._mode != "r+" :
975
901
raise ValueError (
976
- f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got { self ._mode !r} "
902
+ f"``mode`` must be 'r+' when using ``region='auto'``, got { self ._mode !r} "
977
903
)
978
904
region = self ._auto_detect_regions (ds , region )
979
905
0 commit comments