@@ -112,7 +112,9 @@ def __getitem__(self, key):
112
112
# could possibly have a work-around for 0d data here
113
113
114
114
115
- def _determine_zarr_chunks (enc_chunks , var_chunks , ndim , name , safe_chunks ):
115
+ def _determine_zarr_chunks (
116
+ enc_chunks , var_chunks , ndim , name , safe_chunks , region , mode , shape
117
+ ):
116
118
"""
117
119
Given encoding chunks (possibly None or []) and variable chunks
118
120
(possibly None or []).
@@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
163
165
164
166
if len (enc_chunks_tuple ) != ndim :
165
167
# throw away encoding chunks, start over
166
- return _determine_zarr_chunks (None , var_chunks , ndim , name , safe_chunks )
168
+ return _determine_zarr_chunks (
169
+ None , var_chunks , ndim , name , safe_chunks , region , mode , shape
170
+ )
167
171
168
172
for x in enc_chunks_tuple :
169
173
if not isinstance (x , int ):
@@ -189,20 +193,59 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
189
193
# TODO: incorporate synchronizer to allow writes from multiple dask
190
194
# threads
191
195
if var_chunks and enc_chunks_tuple :
192
- for zchunk , dchunks in zip (enc_chunks_tuple , var_chunks , strict = True ):
193
- for dchunk in dchunks [:- 1 ]:
196
+ # If it is possible to write on partial chunks then it is not necessary to check
197
+ # the last one contained on the region
198
+ allow_partial_chunks = mode != "r+"
199
+
200
+ base_error = (
201
+ f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
202
+ f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} "
203
+ f"on the region { region } . "
204
+ f"Writing this array in parallel with dask could lead to corrupted data."
205
+ f"Consider either rechunking using `chunk()`, deleting "
206
+ f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
207
+ )
208
+
209
+ for zchunk , dchunks , interval , size in zip (
210
+ enc_chunks_tuple , var_chunks , region , shape , strict = True
211
+ ):
212
+ if not safe_chunks :
213
+ continue
214
+
215
+ for dchunk in dchunks [1 :- 1 ]:
194
216
if dchunk % zchunk :
195
- base_error = (
196
- f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
197
- f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} . "
198
- f"Writing this array in parallel with dask could lead to corrupted data."
199
- )
200
- if safe_chunks :
201
- raise ValueError (
202
- base_error
203
- + " Consider either rechunking using `chunk()`, deleting "
204
- "or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
205
- )
217
+ raise ValueError (base_error )
218
+
219
+ region_start = interval .start if interval .start else 0
220
+
221
+ if len (dchunks ) > 1 :
222
+ # The first border size is the amount of data that needs to be updated on the
223
+ # first chunk taking into account the region slice.
224
+ first_border_size = zchunk
225
+ if allow_partial_chunks :
226
+ first_border_size = zchunk - region_start % zchunk
227
+
228
+ if (dchunks [0 ] - first_border_size ) % zchunk :
229
+ raise ValueError (base_error )
230
+
231
+ if not allow_partial_chunks :
232
+ chunk_start = sum (dchunks [:- 1 ]) + region_start
233
+ if chunk_start % zchunk :
234
+ # The last chunk which can also be the only one is a partial chunk
235
+ # if it is not aligned at the beginning
236
+ raise ValueError (base_error )
237
+
238
+ region_stop = interval .stop if interval .stop else size
239
+
240
+ if size - region_stop + 1 < zchunk :
241
+ # If the region is covering the last chunk then check
242
+ # if the remainder with the default chunk size
243
+ # is equal to the size of the last chunk
244
+ if dchunks [- 1 ] % zchunk != size % zchunk :
245
+ raise ValueError (base_error )
246
+ elif dchunks [- 1 ] % zchunk :
247
+ raise ValueError (base_error )
248
+
206
249
return enc_chunks_tuple
207
250
208
251
raise AssertionError ("We should never get here. Function logic must be wrong." )
@@ -243,7 +286,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
243
286
244
287
245
288
def extract_zarr_variable_encoding (
246
- variable , raise_on_invalid = False , name = None , safe_chunks = True
289
+ variable ,
290
+ raise_on_invalid = False ,
291
+ name = None ,
292
+ * ,
293
+ safe_chunks = True ,
294
+ region = None ,
295
+ mode = None ,
296
+ shape = None ,
247
297
):
248
298
"""
249
299
Extract zarr encoding dictionary from xarray Variable
@@ -252,12 +302,18 @@ def extract_zarr_variable_encoding(
252
302
----------
253
303
variable : Variable
254
304
raise_on_invalid : bool, optional
255
-
305
+ name: str | Hashable, optional
306
+ safe_chunks: bool, optional
307
+ region: tuple[slice, ...], optional
308
+ mode: str, optional
309
+ shape: tuple[int, ...], optional
256
310
Returns
257
311
-------
258
312
encoding : dict
259
313
Zarr encoding for `variable`
260
314
"""
315
+
316
+ shape = shape if shape else variable .shape
261
317
encoding = variable .encoding .copy ()
262
318
263
319
safe_to_drop = {"source" , "original_shape" }
@@ -285,7 +341,14 @@ def extract_zarr_variable_encoding(
285
341
del encoding [k ]
286
342
287
343
chunks = _determine_zarr_chunks (
288
- encoding .get ("chunks" ), variable .chunks , variable .ndim , name , safe_chunks
344
+ enc_chunks = encoding .get ("chunks" ),
345
+ var_chunks = variable .chunks ,
346
+ ndim = variable .ndim ,
347
+ name = name ,
348
+ safe_chunks = safe_chunks ,
349
+ region = region ,
350
+ mode = mode ,
351
+ shape = shape ,
289
352
)
290
353
encoding ["chunks" ] = chunks
291
354
return encoding
@@ -762,16 +825,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
762
825
if v .encoding == {"_FillValue" : None } and fill_value is None :
763
826
v .encoding = {}
764
827
765
- # We need to do this for both new and existing variables to ensure we're not
766
- # writing to a partial chunk, even though we don't use the `encoding` value
767
- # when writing to an existing variable. See
768
- # https://github.com/pydata/xarray/issues/8371 for details.
769
- encoding = extract_zarr_variable_encoding (
770
- v ,
771
- raise_on_invalid = vn in check_encoding_set ,
772
- name = vn ,
773
- safe_chunks = self ._safe_chunks ,
774
- )
828
+ zarr_array = None
829
+ zarr_shape = None
830
+ write_region = self ._write_region if self ._write_region is not None else {}
831
+ write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
775
832
776
833
if name in existing_keys :
777
834
# existing variable
@@ -801,7 +858,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
801
858
)
802
859
else :
803
860
zarr_array = self .zarr_group [name ]
804
- else :
861
+
862
+ if self ._append_dim is not None and self ._append_dim in dims :
863
+ # resize existing variable
864
+ append_axis = dims .index (self ._append_dim )
865
+ assert write_region [self ._append_dim ] == slice (None )
866
+ write_region [self ._append_dim ] = slice (
867
+ zarr_array .shape [append_axis ], None
868
+ )
869
+
870
+ new_shape = list (zarr_array .shape )
871
+ new_shape [append_axis ] += v .shape [append_axis ]
872
+ zarr_array .resize (new_shape )
873
+
874
+ zarr_shape = zarr_array .shape
875
+
876
+ region = tuple (write_region [dim ] for dim in dims )
877
+
878
+ # We need to do this for both new and existing variables to ensure we're not
879
+ # writing to a partial chunk, even though we don't use the `encoding` value
880
+ # when writing to an existing variable. See
881
+ # https://github.com/pydata/xarray/issues/8371 for details.
882
+ # Note: Ideally there should be two functions, one for validating the chunks and
883
+ # another one for extracting the encoding.
884
+ encoding = extract_zarr_variable_encoding (
885
+ v ,
886
+ raise_on_invalid = vn in check_encoding_set ,
887
+ name = vn ,
888
+ safe_chunks = self ._safe_chunks ,
889
+ region = region ,
890
+ mode = self ._mode ,
891
+ shape = zarr_shape ,
892
+ )
893
+
894
+ if name not in existing_keys :
805
895
# new variable
806
896
encoded_attrs = {}
807
897
# the magic for storing the hidden dimension data
@@ -833,22 +923,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
833
923
)
834
924
zarr_array = _put_attrs (zarr_array , encoded_attrs )
835
925
836
- write_region = self ._write_region if self ._write_region is not None else {}
837
- write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
838
-
839
- if self ._append_dim is not None and self ._append_dim in dims :
840
- # resize existing variable
841
- append_axis = dims .index (self ._append_dim )
842
- assert write_region [self ._append_dim ] == slice (None )
843
- write_region [self ._append_dim ] = slice (
844
- zarr_array .shape [append_axis ], None
845
- )
846
-
847
- new_shape = list (zarr_array .shape )
848
- new_shape [append_axis ] += v .shape [append_axis ]
849
- zarr_array .resize (new_shape )
850
-
851
- region = tuple (write_region [dim ] for dim in dims )
852
926
writer .add (v .data , zarr_array , region )
853
927
854
928
def close (self ) -> None :
@@ -897,9 +971,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
897
971
if not isinstance (region , dict ):
898
972
raise TypeError (f"``region`` must be a dict, got { type (region )} " )
899
973
if any (v == "auto" for v in region .values ()):
900
- if self ._mode != "r+" :
974
+ if self ._mode not in [ "r+" , "a" ] :
901
975
raise ValueError (
902
- f"``mode`` must be 'r+' when using ``region='auto'``, got { self ._mode !r} "
976
+ f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got { self ._mode !r} "
903
977
)
904
978
region = self ._auto_detect_regions (ds , region )
905
979
0 commit comments