-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Automatic region detection and transpose for to_zarr()
#8434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
5b62b7c
b0058ae
17fbca9
463736e
d10c029
7e3419b
5120f1f
c1326c4
809e9e8
9eb1b58
e1a2cb3
d4b8a0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
_normalize_path, | ||
) | ||
from xarray.backends.locks import _get_scheduler | ||
from xarray.backends.zarr import open_zarr | ||
from xarray.core import indexing | ||
from xarray.core.combine import ( | ||
_infer_concat_order_from_positions, | ||
|
@@ -1446,10 +1447,54 @@ def save_mfdataset( | |
) | ||
|
||
|
||
def _validate_region(ds, region): | ||
def _auto_detect_region(ds_new, ds_orig, dim): | ||
# Create a mapping array of coordinates to indices on the original array | ||
coord = ds_orig[dim] | ||
da_map = DataArray(np.arange(coord.size), coords={dim: coord}) | ||
|
||
try: | ||
da_idxs = da_map.sel({dim: ds_new[dim]}) | ||
except KeyError as e: | ||
if "not all values found" in str(e): | ||
raise KeyError( | ||
f"Not all values of coordinate '{dim}' in the new array were" | ||
" found in the original store. Writing to a zarr region slice" | ||
" requires that no dimensions or metadata are changed by the write." | ||
) | ||
else: | ||
raise e | ||
|
||
if (da_idxs.diff(dim) != 1).any(): | ||
raise ValueError( | ||
f"The auto-detected region of coordinate '{dim}' for writing new data" | ||
" to the original store had non-contiguous indices. Writing to a zarr" | ||
" region slice requires that the new data constitute a contiguous subset" | ||
" of the original store." | ||
) | ||
|
||
dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1) | ||
|
||
return dim_slice | ||
|
||
|
||
def _auto_detect_regions(ds, region, open_kwargs): | ||
ds_original = open_zarr(**open_kwargs) | ||
for key, val in region.items(): | ||
if val == "auto": | ||
region[key] = _auto_detect_region(ds, ds_original, key) | ||
return region | ||
|
||
|
||
def _validate_and_autodetect_region(ds, region, open_kwargs) -> dict[str, slice]: | ||
if region == "auto": | ||
region = {dim: "auto" for dim in ds.dims} | ||
|
||
if not isinstance(region, dict): | ||
raise TypeError(f"``region`` must be a dict, got {type(region)}") | ||
|
||
if any(v == "auto" for v in region.values()): | ||
region = _auto_detect_regions(ds, region, open_kwargs) | ||
|
||
for k, v in region.items(): | ||
if k not in ds.dims: | ||
raise ValueError( | ||
|
@@ -1481,6 +1526,8 @@ def _validate_region(ds, region): | |
f".drop_vars({non_matching_vars!r})" | ||
) | ||
|
||
return region | ||
|
||
|
||
def _validate_datatypes_for_zarr_append(zstore, dataset): | ||
"""If variable exists in the store, confirm dtype of the data to append is compatible with | ||
|
@@ -1532,7 +1579,7 @@ def to_zarr( | |
compute: Literal[True] = True, | ||
consolidated: bool | None = None, | ||
append_dim: Hashable | None = None, | ||
region: Mapping[str, slice] | None = None, | ||
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, | ||
safe_chunks: bool = True, | ||
storage_options: dict[str, str] | None = None, | ||
zarr_version: int | None = None, | ||
|
@@ -1556,7 +1603,7 @@ def to_zarr( | |
compute: Literal[False], | ||
consolidated: bool | None = None, | ||
append_dim: Hashable | None = None, | ||
region: Mapping[str, slice] | None = None, | ||
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, | ||
safe_chunks: bool = True, | ||
storage_options: dict[str, str] | None = None, | ||
zarr_version: int | None = None, | ||
|
@@ -1578,7 +1625,7 @@ def to_zarr( | |
compute: bool = True, | ||
consolidated: bool | None = None, | ||
append_dim: Hashable | None = None, | ||
region: Mapping[str, slice] | None = None, | ||
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, | ||
safe_chunks: bool = True, | ||
storage_options: dict[str, str] | None = None, | ||
zarr_version: int | None = None, | ||
|
@@ -1643,7 +1690,15 @@ def to_zarr( | |
_validate_dataset_names(dataset) | ||
|
||
if region is not None: | ||
_validate_region(dataset, region) | ||
open_kwargs = dict( | ||
store=store, | ||
synchronizer=synchronizer, | ||
group=group, | ||
consolidated=consolidated, | ||
storage_options=storage_options, | ||
zarr_version=zarr_version, | ||
) | ||
region = _validate_and_autodetect_region(dataset, region, open_kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Somewhere around here might be the right place to drop the dimension coordinates for |
||
if append_dim is not None and append_dim in region: | ||
raise ValueError( | ||
f"cannot list the same dimension in both ``append_dim`` and " | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5431,3 +5431,128 @@ def test_raise_writing_to_nczarr(self, mode) -> None: | |||||||||||||||||||||||||||||||||||||
def test_pickle_open_mfdataset_dataset(): | ||||||||||||||||||||||||||||||||||||||
ds = open_example_mfdataset(["bears.nc"]) | ||||||||||||||||||||||||||||||||||||||
assert_identical(ds, pickle.loads(pickle.dumps(ds))) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
@requires_zarr | ||||||||||||||||||||||||||||||||||||||
class TestZarrRegionAuto: | ||||||||||||||||||||||||||||||||||||||
def test_zarr_region_auto_all(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||
x = np.arange(0, 50, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
ds.to_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) | ||||||||||||||||||||||||||||||||||||||
ds_region.to_zarr(tmp_path / "test.zarr", region="auto") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_updated = xr.open_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
expected = ds.copy() | ||||||||||||||||||||||||||||||||||||||
expected["test"][2:4, 6:8] += 1 | ||||||||||||||||||||||||||||||||||||||
assert_identical(ds_updated, expected) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be the place to add a test to verify that dimension coordinates are not being overwritten. To accomplish this, you could use a similar pattern to what @olimcc did in #8297: patch the store object so that it counts how many times a certain method has been called xarray/xarray/tests/test_backends.py Lines 2892 to 2909 in e5d163a
Here you would want to patch |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def test_zarr_region_auto_mixed(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||
x = np.arange(0, 50, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
ds.to_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) | ||||||||||||||||||||||||||||||||||||||
ds_region.to_zarr( | ||||||||||||||||||||||||||||||||||||||
tmp_path / "test.zarr", region={"x": "auto", "y": slice(6, 8)} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_updated = xr.open_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
expected = ds.copy() | ||||||||||||||||||||||||||||||||||||||
expected["test"][2:4, 6:8] += 1 | ||||||||||||||||||||||||||||||||||||||
assert_identical(ds_updated, expected) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def test_zarr_region_auto_noncontiguous(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||
x = np.arange(0, 50, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
ds.to_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_region = 1 + ds.isel(x=[0, 2, 3], y=[5, 6]) | ||||||||||||||||||||||||||||||||||||||
with pytest.raises(ValueError): | ||||||||||||||||||||||||||||||||||||||
ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def test_zarr_region_auto_new_coord_vals(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||
x = np.arange(0, 50, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
ds.to_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
x = np.arange(5, 55, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) | ||||||||||||||||||||||||||||||||||||||
with pytest.raises(KeyError): | ||||||||||||||||||||||||||||||||||||||
ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def test_zarr_region_transpose(tmp_path): | ||||||||||||||||||||||||||||||||||||||
slevang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||
x = np.arange(0, 50, 10) | ||||||||||||||||||||||||||||||||||||||
y = np.arange(0, 20, 2) | ||||||||||||||||||||||||||||||||||||||
data = np.ones((5, 10)) | ||||||||||||||||||||||||||||||||||||||
ds = xr.Dataset( | ||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||
"test": xr.DataArray( | ||||||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||||||
dims=("x", "y"), | ||||||||||||||||||||||||||||||||||||||
coords={"x": x, "y": y}, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
ds.to_zarr(tmp_path / "test.zarr") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
ds_region = 1 + ds.isel(x=[0], y=[0]).transpose() | ||||||||||||||||||||||||||||||||||||||
ds_region.to_zarr( | ||||||||||||||||||||||||||||||||||||||
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)} | ||||||||||||||||||||||||||||||||||||||
) |
Uh oh!
There was an error while loading. Please reload this page.