Skip to content

Commit 4cd2c4b

Browse files
committed
Add drop=True argument to isel, sel and squeeze
Fixes GH242 This is useful for getting rid of extraneous scalar variables that arise from indexing, and in particular will resolve an issue for optional indexes: pydata#1017 (comment)
1 parent d01077a commit 4cd2c4b

File tree

7 files changed

+132
-27
lines changed

7 files changed

+132
-27
lines changed

doc/whats-new.rst

+6
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ Enhancements
101101
providing consistent access to dimension length on both ``Dataset`` and
102102
``DataArray`` (:issue:`921`).
103103
By `Stephan Hoyer <https://github.com/shoyer>`_.
104+
- New keyword argument ``drop=True`` for :py:meth:`~DataArray.sel`,
105+
:py:meth:`~DataArray.isel` and :py:meth:`~DataArray.squeeze` for dropping
106+
scalar coordinates that arise from indexing.
107+
``DataArray`` (:issue:`242`).
108+
By `Stephan Hoyer <https://github.com/shoyer>`_.
109+
104110
- New top-level functions :py:func:`~xarray.full_like`,
105111
:py:func:`~xarray.zeros_like`, and :py:func:`~xarray.ones_like`
106112
By `Guido Imperiale <https://github.com/crusaderky>`_.

xarray/core/common.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,24 @@ def __dir__(self):
249249
return sorted(set(dir(type(self)) + extra_attrs))
250250

251251

252-
class SharedMethodsMixin(object):
253-
"""Shared methods for Dataset, DataArray and Variable."""
252+
def get_squeeze_dims(xarray_obj, dim):
253+
"""Get a list of dimensions to squeeze out.
254+
"""
255+
if dim is None:
256+
dim = [d for d, s in xarray_obj.sizes.items() if s == 1]
257+
else:
258+
if isinstance(dim, basestring):
259+
dim = [dim]
260+
if any(xarray_obj.sizes[k] > 1 for k in dim):
261+
raise ValueError('cannot select a dimension to squeeze out '
262+
'which has length greater than one')
263+
return dim
264+
265+
266+
class BaseDataObject(AttrAccessMixin):
267+
"""Shared base class for Dataset and DataArray."""
254268

255-
def squeeze(self, dim=None):
269+
def squeeze(self, dim=None, drop=False):
256270
"""Return a new object with squeezed data.
257271
258272
Parameters
@@ -261,6 +275,9 @@ def squeeze(self, dim=None):
261275
Selects a subset of the length one dimensions. If a dimension is
262276
selected with length greater than one, an error is raised. If
263277
None, all length one dimensions are squeezed.
278+
drop : bool, optional
279+
If ``drop=True``, drop squeezed coordinates instead of making them
280+
scalar.
264281
265282
Returns
266283
-------
@@ -272,19 +289,8 @@ def squeeze(self, dim=None):
272289
--------
273290
numpy.squeeze
274291
"""
275-
if dim is None:
276-
dim = [d for d, s in self.sizes.items() if s == 1]
277-
else:
278-
if isinstance(dim, basestring):
279-
dim = [dim]
280-
if any(self.sizes[k] > 1 for k in dim):
281-
raise ValueError('cannot select a dimension to squeeze out '
282-
'which has length greater than one')
283-
return self.isel(**{d: 0 for d in dim})
284-
285-
286-
class BaseDataObject(SharedMethodsMixin, AttrAccessMixin):
287-
"""Shared base class for Dataset and DataArray."""
292+
dims = get_squeeze_dims(self, dim)
293+
return self.isel(drop=drop, **{d: 0 for d in dims})
288294

289295
def _calc_assign_results(self, kwargs):
290296
results = SortedKeysDict()

xarray/core/dataarray.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def chunk(self, chunks=None):
640640
ds = self._to_temp_dataset().chunk(chunks)
641641
return self._from_temp_dataset(ds)
642642

643-
def isel(self, **indexers):
643+
def isel(self, drop=False, **indexers):
644644
"""Return a new DataArray whose dataset is given by integer indexing
645645
along the specified dimension(s).
646646
@@ -649,10 +649,10 @@ def isel(self, **indexers):
649649
Dataset.isel
650650
DataArray.sel
651651
"""
652-
ds = self._to_temp_dataset().isel(**indexers)
652+
ds = self._to_temp_dataset().isel(drop=drop, **indexers)
653653
return self._from_temp_dataset(ds)
654654

655-
def sel(self, method=None, tolerance=None, **indexers):
655+
def sel(self, method=None, tolerance=None, drop=False, **indexers):
656656
"""Return a new DataArray whose dataset is given by selecting
657657
index labels along the specified dimension(s).
658658
@@ -664,7 +664,8 @@ def sel(self, method=None, tolerance=None, **indexers):
664664
pos_indexers, new_indexes = indexing.remap_label_indexers(
665665
self, indexers, method=method, tolerance=tolerance
666666
)
667-
return self.isel(**pos_indexers)._replace_indexes(new_indexes)
667+
result = self.isel(drop=drop, **pos_indexers)
668+
return result._replace_indexes(new_indexes)
668669

669670
def isel_points(self, dim='points', **indexers):
670671
"""Return a new DataArray whose dataset is given by pointwise integer

xarray/core/dataset.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def maybe_chunk(name, var, chunks):
892892
for k, v in self.variables.items()])
893893
return self._replace_vars_and_dims(variables)
894894

895-
def isel(self, **indexers):
895+
def isel(self, drop=False, **indexers):
896896
"""Returns a new dataset with each array indexed along the specified
897897
dimension(s).
898898
@@ -902,6 +902,9 @@ def isel(self, **indexers):
902902
903903
Parameters
904904
----------
905+
drop : bool, optional
906+
If ``drop=True``, drop coordinates variables indexed by integers
907+
instead of making them scalar.
905908
**indexers : {dim: indexer, ...}
906909
Keyword arguments with names matching dimensions and values given
907910
by integers, slice objects or arrays.
@@ -935,10 +938,13 @@ def isel(self, **indexers):
935938
variables = OrderedDict()
936939
for name, var in iteritems(self._variables):
937940
var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
938-
variables[name] = var.isel(**var_indexers)
939-
return self._replace_vars_and_dims(variables)
941+
new_var = var.isel(**var_indexers)
942+
if not (drop and name in var_indexers):
943+
variables[name] = new_var
944+
coord_names = set(self._coord_names) & set(variables)
945+
return self._replace_vars_and_dims(variables, coord_names=coord_names)
940946

941-
def sel(self, method=None, tolerance=None, **indexers):
947+
def sel(self, method=None, tolerance=None, drop=False, **indexers):
942948
"""Returns a new dataset with each array indexed by tick labels
943949
along the specified dimension(s).
944950
@@ -969,6 +975,9 @@ def sel(self, method=None, tolerance=None, **indexers):
969975
matches. The values of the index at the matching locations most
970976
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
971977
Requires pandas>=0.17.
978+
drop : bool, optional
979+
If ``drop=True``, drop coordinates variables in `indexers` instead
980+
of making them scalar.
972981
**indexers : {dim: indexer, ...}
973982
Keyword arguments with names matching dimensions and values given
974983
by scalars, slices or arrays of tick labels. For dimensions with
@@ -994,7 +1003,8 @@ def sel(self, method=None, tolerance=None, **indexers):
9941003
pos_indexers, new_indexes = indexing.remap_label_indexers(
9951004
self, indexers, method=method, tolerance=tolerance
9961005
)
997-
return self.isel(**pos_indexers)._replace_indexes(new_indexes)
1006+
result = self.isel(drop=drop, **pos_indexers)
1007+
return result._replace_indexes(new_indexes)
9981008

9991009
def isel_points(self, dim='points', **indexers):
10001010
"""Returns a new dataset with each array indexed pointwise along the

xarray/core/variable.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ def _as_array_or_item(data):
196196
return data
197197

198198

199-
class Variable(common.AbstractArray, common.SharedMethodsMixin,
200-
utils.NdimSizeLenMixin):
199+
class Variable(common.AbstractArray, utils.NdimSizeLenMixin):
201200

202201
"""A netcdf-like variable consisting of dimensions, data and attributes
203202
which describe a single Array. A single Variable object is not fully
@@ -553,6 +552,29 @@ def isel(self, **indexers):
553552
key[i] = indexers[dim]
554553
return self[tuple(key)]
555554

555+
def squeeze(self, dim=None):
556+
"""Return a new object with squeezed data.
557+
558+
Parameters
559+
----------
560+
dim : None or str or tuple of str, optional
561+
Selects a subset of the length one dimensions. If a dimension is
562+
selected with length greater than one, an error is raised. If
563+
None, all length one dimensions are squeezed.
564+
565+
Returns
566+
-------
567+
squeezed : same type as caller
568+
This object, but with with all or a subset of the dimensions of
569+
length 1 removed.
570+
571+
See Also
572+
--------
573+
numpy.squeeze
574+
"""
575+
dims = common.get_squeeze_dims(self, dim)
576+
return self.isel(**{d: 0 for d in dims})
577+
556578
def _shift_one_dim(self, dim, count):
557579
axis = self.get_axis_num(dim)
558580

xarray/test/test_dataarray.py

+30
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,26 @@ def test_sel_method(self):
498498
with self.assertRaisesRegexp(NotImplementedError, 'tolerance'):
499499
data.sel(x=[0.9, 1.9], method='backfill', tolerance=1)
500500

501+
def test_sel_drop(self):
502+
data = DataArray([1, 2, 3], [('x', [0, 1, 2])])
503+
expected = DataArray(1)
504+
selected = data.sel(x=0, drop=True)
505+
self.assertDataArrayIdentical(expected, selected)
506+
507+
expected = DataArray(1, {'x': 0})
508+
selected = data.sel(x=0, drop=False)
509+
self.assertDataArrayIdentical(expected, selected)
510+
511+
def test_isel_drop(self):
512+
data = DataArray([1, 2, 3], [('x', [0, 1, 2])])
513+
expected = DataArray(1)
514+
selected = data.isel(x=0, drop=True)
515+
self.assertDataArrayIdentical(expected, selected)
516+
517+
expected = DataArray(1, {'x': 0})
518+
selected = data.isel(x=0, drop=False)
519+
self.assertDataArrayIdentical(expected, selected)
520+
501521
def test_isel_points(self):
502522
shape = (10, 5, 6)
503523
np_array = np.random.random(shape)
@@ -1064,6 +1084,16 @@ def test_transpose(self):
10641084
def test_squeeze(self):
10651085
self.assertVariableEqual(self.dv.variable.squeeze(), self.dv.squeeze())
10661086

1087+
def test_squeeze_drop(self):
1088+
array = DataArray([1], [('x', [0])])
1089+
expected = DataArray(1)
1090+
actual = array.squeeze(drop=True)
1091+
self.assertDataArrayIdentical(expected, actual)
1092+
1093+
expected = DataArray(1, {'x': 0})
1094+
actual = array.squeeze(drop=False)
1095+
self.assertDataArrayIdentical(expected, actual)
1096+
10671097
def test_drop_coordinates(self):
10681098
expected = DataArray(np.random.randn(2, 3), dims=['x', 'y'])
10691099
arr = expected.copy()

xarray/test/test_dataset.py

+30
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,26 @@ def test_sel(self):
815815
self.assertDatasetEqual(data.isel(td=slice(1, 3)),
816816
data.sel(td=slice('1 days', '2 days')))
817817

818+
def test_sel_drop(self):
819+
data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]})
820+
expected = Dataset({'foo': 1})
821+
selected = data.sel(x=0, drop=True)
822+
self.assertDatasetIdentical(expected, selected)
823+
824+
expected = Dataset({'foo': 1}, {'x': 0})
825+
selected = data.sel(x=0, drop=False)
826+
self.assertDatasetIdentical(expected, selected)
827+
828+
def test_isel_drop(self):
829+
data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]})
830+
expected = Dataset({'foo': 1})
831+
selected = data.isel(x=0, drop=True)
832+
self.assertDatasetIdentical(expected, selected)
833+
834+
expected = Dataset({'foo': 1}, {'x': 0})
835+
selected = data.isel(x=0, drop=False)
836+
self.assertDatasetIdentical(expected, selected)
837+
818838
def test_isel_points(self):
819839
data = create_test_data()
820840

@@ -1750,6 +1770,16 @@ def get_args(v):
17501770
with self.assertRaisesRegexp(ValueError, 'cannot select a dimension'):
17511771
data.squeeze('y')
17521772

1773+
def test_squeeze_drop(self):
1774+
data = Dataset({'foo': ('x', [1])}, {'x': [0]})
1775+
expected = Dataset({'foo': 1})
1776+
selected = data.squeeze(drop=True)
1777+
self.assertDatasetIdentical(expected, selected)
1778+
1779+
expected = Dataset({'foo': 1}, {'x': 0})
1780+
selected = data.squeeze(drop=False)
1781+
self.assertDatasetIdentical(expected, selected)
1782+
17531783
def test_groupby(self):
17541784
data = Dataset({'z': (['x', 'y'], np.random.randn(3, 5))},
17551785
{'x': ('x', list('abc')),

0 commit comments

Comments
 (0)