Commit 6057128

Authored by TomNicholas, pre-commit-ci[bot], dcherian, and keewis
Avoid auto creation of indexes in concat (#8872)
* test not creating indexes on concatenation
* construct result dataset using Coordinates object with indexes passed explicitly
* remove unnecessary overwriting of indexes
* ConcatenatableArray class
* use ConcatenableArray in tests
* add regression tests
* fix by performing check
* refactor assert_valid_explicit_coords and rename dims->sizes
* Revert "add regression tests" (reverts commit beb665a)
* Revert "fix by performing check" (reverts commit 22f361d)
* Revert "refactor assert_valid_explicit_coords and rename dims->sizes" (reverts commit 55166fc)
* fix failing test
* possible fix for failing groupby test
* Revert "possible fix for failing groupby test" (reverts commit 6e9ead6)
* test expand_dims doesn't create Index
* add option to not create 1D index in expand_dims
* refactor tests to consider data variables and coordinate variables separately
* test expand_dims doesn't create Index
* add option to not create 1D index in expand_dims
* refactor tests to consider data variables and coordinate variables separately
* fix bug causing new test to fail
* test index auto-creation when iterable passed as new coordinate values
* make test for iterable pass
* added kwarg to dataarray
* whatsnew
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Revert "refactor tests to consider data variables and coordinate variables separately" (reverts commit ba5627e)
* Revert "add option to not create 1D index in expand_dims" (reverts commit 95d453c)
* test that concat doesn't raise if create_1d_index=False
* make test pass by passing create_1d_index down through concat
* assert that an UnexpectedDataAccess error is raised when create_1d_index=True
* eliminate possibility of xarray internals bypassing UnexpectedDataAccess error by accessing .array
* update tests to use private versions of assertions
* create_1d_index->create_index
* Update doc/whats-new.rst (Co-authored-by: Deepak Cherian <[email protected]>)
* Rename create_1d_index -> create_index
* fix ConcatenatableArray
* formatting
* whatsnew
* add new create_index kwarg to overloads
* split vars into data_vars and coord_vars in one loop
* avoid mypy error by using new variable name
* warn if create_index=True but no index created because dimension variable was a data var not a coord
* add string marks in warning message
* regression test for dtype changing in to_stacked_array
* correct doctest
* Remove outdated comment
* test we can skip creation of indexes during shape promotion
* make shape promotion test pass
* point to issue in whatsnew
* don't create dimension coordinates just to drop them at the end
* Remove ToDo about not using Coordinates object to pass indexes (Co-authored-by: Deepak Cherian <[email protected]>)
* get rid of unlabeled_dims variable entirely
* move ConcatenatableArray and similar to new file
* formatting nit (Co-authored-by: Justus Magin <[email protected]>)
* renamed create_index -> create_index_for_new_dim in concat
* renamed create_index -> create_index_for_new_dim in expand_dims
* fix incorrect arg name
* add example to docstring
* add example of using new kwarg to docstring of expand_dims
* add example of using new kwarg to docstring of concat
* re-nit the nit (Co-authored-by: Justus Magin <[email protected]>)
* more instances of the nit
* fix docstring doctest formatting nit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Justus Magin <[email protected]>
1 parent: 71661d5

File tree

8 files changed: +405 -84 lines

doc/whats-new.rst (+4 -1)

@@ -32,7 +32,10 @@ New Features
 - :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
   By `Ignacio Martinez Vazquez <https://github.com/ignamv>`_.
 - Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
-  `create_index=False`. (:pull:`8960`)
+  `create_index_for_new_dim=False`. (:pull:`8960`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
+- Avoid automatically re-creating 1D pandas indexes in :py:func:`concat()`. Also added option to avoid creating 1D indexes for
+  new dimension coordinates by passing the new kwarg `create_index_for_new_dim=False`. (:issue:`8871`, :pull:`8872`)
   By `Tom Nicholas <https://github.com/TomNicholas>`_.

 Breaking changes
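The whats-new entries can be illustrated directly. A minimal sketch of the new `create_index_for_new_dim` kwarg on `concat`, mirroring the doctest added in this commit (assuming an xarray release that includes this change):

```python
import xarray as xr

# Concatenating a scalar coordinate along a new dimension of the same name
ds = xr.Dataset(coords={"x": 0})

# Default behaviour: a PandasIndex is created for the new "x" dimension
with_index = xr.concat([ds, ds], dim="x")

# Opting out with the new kwarg leaves the result unindexed
without_index = xr.concat([ds, ds], dim="x", create_index_for_new_dim=False)

print("x" in with_index.indexes)
print("x" in without_index.indexes)
```

Both results have the concatenated dimension of length 2; they differ only in whether a pandas index backs it.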

xarray/core/concat.py (+52 -15)

@@ -8,6 +8,7 @@
 
 from xarray.core import dtypes, utils
 from xarray.core.alignment import align, reindex_variables
+from xarray.core.coordinates import Coordinates
 from xarray.core.duck_array_ops import lazy_array_equiv
 from xarray.core.indexes import Index, PandasIndex
 from xarray.core.merge import (
@@ -42,6 +43,7 @@ def concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_Dataset: ...
 
 
@@ -56,6 +58,7 @@ def concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_DataArray: ...
 
 
@@ -69,6 +72,7 @@ def concat(
     fill_value=dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ):
     """Concatenate xarray objects along a new or existing dimension.
 
@@ -162,6 +166,8 @@ def concat(
         If a callable, it must expect a sequence of ``attrs`` dicts and a context object
         as its only parameters.
+    create_index_for_new_dim : bool, default: True
+        Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
 
     Returns
     -------
@@ -217,6 +223,25 @@ def concat(
         x        (new_dim) <U1 8B 'a' 'b'
       * y        (y) int64 24B 10 20 30
       * new_dim  (new_dim) int64 16B -90 -100
+
+    # Concatenate a scalar variable along a new dimension of the same name with and without creating a new index
+
+    >>> ds = xr.Dataset(coords={"x": 0})
+    >>> xr.concat([ds, ds], dim="x")
+    <xarray.Dataset> Size: 16B
+    Dimensions:  (x: 2)
+    Coordinates:
+      * x        (x) int64 16B 0 0
+    Data variables:
+        *empty*
+
+    >>> xr.concat([ds, ds], dim="x").indexes
+    Indexes:
+        x        Index([0, 0], dtype='int64', name='x')
+
+    >>> xr.concat([ds, ds], dim="x", create_index_for_new_dim=False).indexes
+    Indexes:
+        *empty*
     """
     # TODO: add ignore_index arguments copied from pandas.concat
     # TODO: support concatenating scalar coordinates even if the concatenated
@@ -245,6 +270,7 @@ def concat(
             fill_value=fill_value,
             join=join,
             combine_attrs=combine_attrs,
+            create_index_for_new_dim=create_index_for_new_dim,
         )
     elif isinstance(first_obj, Dataset):
         return _dataset_concat(
@@ -257,6 +283,7 @@ def concat(
             fill_value=fill_value,
             join=join,
             combine_attrs=combine_attrs,
+            create_index_for_new_dim=create_index_for_new_dim,
         )
     else:
         raise TypeError(
@@ -439,7 +466,7 @@ def _parse_datasets(
             if dim in dims:
                 continue
 
-            if dim not in dim_coords:
+            if dim in ds.coords and dim not in dim_coords:
                 dim_coords[dim] = ds.coords[dim].variable
         dims = dims | set(ds.dims)
 
@@ -456,6 +483,7 @@ def _dataset_concat(
     fill_value: Any = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_Dataset:
     """
     Concatenate a sequence of datasets along a new or existing dimension
@@ -489,7 +517,6 @@ def _dataset_concat(
         datasets
     )
     dim_names = set(dim_coords)
-    unlabeled_dims = dim_names - coord_names
 
     both_data_and_coords = coord_names & data_names
     if both_data_and_coords:
@@ -502,15 +529,18 @@ def _dataset_concat(
 
     # case where concat dimension is a coordinate or data_var but not a dimension
     if (dim in coord_names or dim in data_names) and dim not in dim_names:
-        datasets = [ds.expand_dims(dim) for ds in datasets]
+        datasets = [
+            ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim)
+            for ds in datasets
+        ]
 
     # determine which variables to concatenate
     concat_over, equals, concat_dim_lengths = _calc_concat_over(
         datasets, dim, dim_names, data_vars, coords, compat
     )
 
     # determine which variables to merge, and then merge them according to compat
-    variables_to_merge = (coord_names | data_names) - concat_over - unlabeled_dims
+    variables_to_merge = (coord_names | data_names) - concat_over
 
     result_vars = {}
     result_indexes = {}
@@ -567,7 +597,8 @@ def get_indexes(name):
             var = ds._variables[name]
             if not var.dims:
                 data = var.set_dims(dim).values
-                yield PandasIndex(data, dim, coord_dtype=var.dtype)
+                if create_index_for_new_dim:
+                    yield PandasIndex(data, dim, coord_dtype=var.dtype)
 
     # create concatenation index, needed for later reindexing
     file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
@@ -646,29 +677,33 @@ def get_indexes(name):
             # preserves original variable order
             result_vars[name] = result_vars.pop(name)
 
-    result = type(datasets[0])(result_vars, attrs=result_attrs)
-
-    absent_coord_names = coord_names - set(result.variables)
+    absent_coord_names = coord_names - set(result_vars)
     if absent_coord_names:
         raise ValueError(
             f"Variables {absent_coord_names!r} are coordinates in some datasets but not others."
         )
-    result = result.set_coords(coord_names)
-    result.encoding = result_encoding
 
-    result = result.drop_vars(unlabeled_dims, errors="ignore")
+    result_data_vars = {}
+    coord_vars = {}
+    for name, result_var in result_vars.items():
+        if name in coord_names:
+            coord_vars[name] = result_var
+        else:
+            result_data_vars[name] = result_var
 
     if index is not None:
-        # add concat index / coordinate last to ensure that its in the final Dataset
         if dim_var is not None:
            index_vars = index.create_variables({dim: dim_var})
         else:
            index_vars = index.create_variables()
-        result[dim] = index_vars[dim]
+
+        coord_vars[dim] = index_vars[dim]
         result_indexes[dim] = index
 
-    # TODO: add indexes at Dataset creation (when it is supported)
-    result = result._overwrite_indexes(result_indexes)
+    coords_obj = Coordinates(coord_vars, indexes=result_indexes)
+
+    result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs)
+    result.encoding = result_encoding
 
     return result
 
@@ -683,6 +718,7 @@ def _dataarray_concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_DataArray:
     from xarray.core.dataarray import DataArray
 
@@ -719,6 +755,7 @@ def _dataarray_concat(
         fill_value=fill_value,
         join=join,
         combine_attrs=combine_attrs,
+        create_index_for_new_dim=create_index_for_new_dim,
     )
 
     merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)

xarray/core/dataarray.py (+7 -5)

@@ -2558,7 +2558,7 @@ def expand_dims(
         self,
         dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
         axis: None | int | Sequence[int] = None,
-        create_index: bool = True,
+        create_index_for_new_dim: bool = True,
         **dim_kwargs: Any,
     ) -> Self:
         """Return a new object with an additional axis (or axes) inserted at
@@ -2569,7 +2569,7 @@ def expand_dims(
         coordinate consisting of a single value.
 
         The automatic creation of indexes to back new 1D coordinate variables
-        controlled by the create_index kwarg.
+        controlled by the create_index_for_new_dim kwarg.
 
         Parameters
         ----------
@@ -2586,8 +2586,8 @@ def expand_dims(
             multiple axes are inserted. In this case, dim arguments should be
             same length list. If axis=None is passed, all the axes will be
             inserted to the start of the result array.
-        create_index : bool, default is True
-            Whether to create new PandasIndex objects for any new 1D coordinate variables.
+        create_index_for_new_dim : bool, default: True
+            Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``.
         **dim_kwargs : int or sequence or ndarray
             The keywords are arbitrary dimensions being inserted and the values
             are either the lengths of the new dims (if int is given), or their
@@ -2651,7 +2651,9 @@ def expand_dims(
             dim = {dim: 1}
 
         dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
-        ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index)
+        ds = self._to_temp_dataset().expand_dims(
+            dim, axis, create_index_for_new_dim=create_index_for_new_dim
+        )
         return self._from_temp_dataset(ds)
 
     def set_index(
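The DataArray change above simply threads the renamed kwarg through `_to_temp_dataset()`. A small sketch of the resulting behaviour (assuming an xarray release that includes this change):

```python
import xarray as xr

da = xr.DataArray(5, coords={"x": 0}, name="data")

# Promote the scalar coordinate "x" to a length-1 dimension without
# building a backing PandasIndex
expanded = da.expand_dims("x", create_index_for_new_dim=False)

print(expanded.sizes["x"])
print("x" in expanded.indexes)
```

With the default `create_index_for_new_dim=True`, the same call would create an index for the new dimension, exactly as before this change.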

xarray/core/dataset.py (+35 -8)

@@ -4513,7 +4513,7 @@ def expand_dims(
         self,
         dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
         axis: None | int | Sequence[int] = None,
-        create_index: bool = True,
+        create_index_for_new_dim: bool = True,
         **dim_kwargs: Any,
     ) -> Self:
         """Return a new object with an additional axis (or axes) inserted at
@@ -4524,7 +4524,7 @@ def expand_dims(
         coordinate consisting of a single value.
 
         The automatic creation of indexes to back new 1D coordinate variables
-        controlled by the create_index kwarg.
+        controlled by the create_index_for_new_dim kwarg.
 
         Parameters
         ----------
@@ -4541,8 +4541,8 @@ def expand_dims(
             multiple axes are inserted. In this case, dim arguments should be
             same length list. If axis=None is passed, all the axes will be
             inserted to the start of the result array.
-        create_index : bool, default is True
-            Whether to create new PandasIndex objects for any new 1D coordinate variables.
+        create_index_for_new_dim : bool, default: True
+            Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``.
         **dim_kwargs : int or sequence or ndarray
             The keywords are arbitrary dimensions being inserted and the values
             are either the lengths of the new dims (if int is given), or their
@@ -4612,6 +4612,33 @@ def expand_dims(
         Data variables:
             temperature  (y, x, time) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289
 
+        # Expand a scalar variable along a new dimension of the same name with and without creating a new index
+
+        >>> ds = xr.Dataset(coords={"x": 0})
+        >>> ds
+        <xarray.Dataset> Size: 8B
+        Dimensions:  ()
+        Coordinates:
+            x        int64 8B 0
+        Data variables:
+            *empty*
+
+        >>> ds.expand_dims("x")
+        <xarray.Dataset> Size: 8B
+        Dimensions:  (x: 1)
+        Coordinates:
+          * x        (x) int64 8B 0
+        Data variables:
+            *empty*
+
+        >>> ds.expand_dims("x").indexes
+        Indexes:
+            x        Index([0], dtype='int64', name='x')
+
+        >>> ds.expand_dims("x", create_index_for_new_dim=False).indexes
+        Indexes:
+            *empty*
+
         See Also
         --------
         DataArray.expand_dims
@@ -4663,7 +4690,7 @@ def expand_dims(
                 # value within the dim dict to the length of the iterable
                 # for later use.
 
-                if create_index:
+                if create_index_for_new_dim:
                     index = PandasIndex(v, k)
                     indexes[k] = index
                     name_and_new_1d_var = index.create_variables()
@@ -4705,14 +4732,14 @@ def expand_dims(
                 variables[k] = v.set_dims(dict(all_dims))
             else:
                 if k not in variables:
-                    if k in coord_names and create_index:
+                    if k in coord_names and create_index_for_new_dim:
                         # If dims includes a label of a non-dimension coordinate,
                         # it will be promoted to a 1D coordinate with a single value.
                         index, index_vars = create_default_index_implicit(v.set_dims(k))
                         indexes[k] = index
                         variables.update(index_vars)
                     else:
-                        if create_index:
+                        if create_index_for_new_dim:
                             warnings.warn(
                                 f"No index created for dimension {k} because variable {k} is not a coordinate. "
                                 f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.",
@@ -5400,7 +5427,7 @@ def to_stacked_array(
         [3, 4, 5, 7]])
         Coordinates:
         * z         (z) object 32B MultiIndex
-        * variable  (z) object 32B 'a' 'a' 'a' 'b'
+        * variable  (z) <U1 16B 'a' 'a' 'a' 'b'
         * y         (z) object 32B 'u' 'v' 'w' nan
         Dimensions without coordinates: x
