Skip to content

Commit 05079ae

Browse files
committed
New properties Dataset.sizes and DataArray.sizes
This allows for consistent access to dimension lengths on ``Dataset`` and ``DataArray`` xref #921 (doesn't resolve it 100%, but should help significantly)
1 parent a4f5ec2 commit 05079ae

File tree

6 files changed

+113
-73
lines changed

6 files changed

+113
-73
lines changed

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ Enhancements
7777
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
7878
<https://github.com/pwolfram>`_.
7979

80+
- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for
81+
providing consistent access to dimension length on both ``Dataset`` and
82+
``DataArray`` (:issue:`921`).
83+
By `Stephan Hoyer <https://github.com/shoyer>`_.
84+
8085
Bug fixes
8186
~~~~~~~~~
8287
- ``groupby_bins`` now restores empty bins by default (:issue:`1019`).

xarray/core/common.py

+30-13
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,36 @@ def pipe(self, func, *args, **kwargs):
335335
else:
336336
return func(self, *args, **kwargs)
337337

338+
def squeeze(self, dim=None):
339+
"""Return a new object with squeezed data.
340+
341+
Parameters
342+
----------
343+
dim : None or str or tuple of str, optional
344+
Selects a subset of the length one dimensions. If a dimension is
345+
selected with length greater than one, an error is raised. If
346+
None, all length one dimensions are squeezed.
347+
348+
Returns
349+
-------
350+
squeezed : same type as caller
351+
This object, but with with all or a subset of the dimensions of
352+
length 1 removed.
353+
354+
See Also
355+
--------
356+
numpy.squeeze
357+
"""
358+
if dim is None:
359+
dim = [d for d, s in self.sizes.item() if s == 1]
360+
else:
361+
if isinstance(dim, basestring):
362+
dim = [dim]
363+
if any(self.sizes[k] > 1 for k in dim):
364+
raise ValueError('cannot select a dimension to squeeze out '
365+
'which has length greater than one')
366+
return self.isel(**dict((d, 0) for d in dim))
367+
338368
def groupby(self, group, squeeze=True):
339369
"""Returns a GroupBy object for performing grouped operations.
340370
@@ -615,19 +645,6 @@ def __exit__(self, exc_type, exc_value, traceback):
615645
__or__ = __div__ = __eq__ = __ne__ = not_implemented
616646

617647

618-
def squeeze(xarray_obj, dims, dim=None):
619-
"""Squeeze the dims of an xarray object."""
620-
if dim is None:
621-
dim = [d for d, s in iteritems(dims) if s == 1]
622-
else:
623-
if isinstance(dim, basestring):
624-
dim = [dim]
625-
if any(dims[k] > 1 for k in dim):
626-
raise ValueError('cannot select a dimension to squeeze out '
627-
'which has length greater than one')
628-
return xarray_obj.isel(**dict((d, 0) for d in dim))
629-
630-
631648
def _maybe_promote(dtype):
632649
"""Simpler equivalent of pandas.core.common._maybe_promote"""
633650
# N.B. these casting rules should match pandas

xarray/core/dataarray.py

+45-31
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import contextlib
23
import functools
34
import warnings
@@ -7,6 +8,7 @@
78

89
from ..plot.plot import _PlotMethods
910

11+
from . import formatting
1012
from . import indexing
1113
from . import groupby
1214
from . import rolling
@@ -72,18 +74,17 @@ def _infer_coords_and_dims(shape, coords, dims):
7274
if dim not in new_coords:
7375
new_coords[dim] = default_index_coordinate(dim, size)
7476

75-
sizes = dict(zip(dims, shape))
7677
for k, v in new_coords.items():
7778
if any(d not in dims for d in v.dims):
7879
raise ValueError('coordinate %s has dimensions %s, but these '
7980
'are not a subset of the DataArray '
8081
'dimensions %s' % (k, v.dims, dims))
8182

8283
for d, s in zip(v.dims, v.shape):
83-
if s != sizes[d]:
84+
if s != self.sizes[d]:
8485
raise ValueError('conflicting sizes for dimension %r: '
8586
'length %s on the data but length %s on '
86-
'coordinate %r' % (d, sizes[d], s, k))
87+
'coordinate %r' % (d, self.sizes[d], s, k))
8788

8889
assert_unique_multiindex_level_names(new_coords)
8990

@@ -110,6 +111,31 @@ def __setitem__(self, key, value):
110111
self.data_array[pos_indexers] = value
111112

112113

114+
class DataArraySizes(collections.Mapping, formatting.ReprMixin):
115+
def __init__(self, array):
116+
self._array = array
117+
118+
def __getitem__(self, key):
119+
try:
120+
index = self._array.dims.index(key)
121+
except ValueError:
122+
raise KeyError(key)
123+
return self._array.shape[index]
124+
125+
def __contains__(self, key):
126+
return key in self._array.dims
127+
128+
def __iter__(self):
129+
return iter(self._array.dims)
130+
131+
def __len__(self):
132+
return len(self._array.dims)
133+
134+
def __unicode__(self):
135+
contents = ', '.join(u'%s: %s' % (k, v) for k, v in self.items())
136+
return u'<%s (%s)>' % (type(self).__name__, contents)
137+
138+
113139
class _ThisArray(object):
114140
"""An instance of this object is used as the key corresponding to the
115141
variable when converting arbitrary DataArray objects to datasets
@@ -411,14 +437,29 @@ def to_index(self):
411437

412438
@property
413439
def dims(self):
414-
"""Dimension names associated with this array."""
440+
"""Tuple of dimension names associated with this array.
441+
442+
Note that the type of this property is inconsistent with `Dataset.dims`.
443+
See `Dataset.sizes` and `DataArray.sizes` for consistently named
444+
properties.
445+
"""
415446
return self.variable.dims
416447

417448
@dims.setter
418449
def dims(self, value):
419450
raise AttributeError('you cannot assign dims on a DataArray. Use '
420451
'.rename() or .swap_dims() instead.')
421452

453+
@property
454+
def sizes(self):
455+
"""Mapping from dimension names to lengths.
456+
457+
See also
458+
--------
459+
Dataset.sizes
460+
"""
461+
return DataArraySizes(self)
462+
422463
def _item_key_to_dict(self, key):
423464
if utils.is_dict_like(key):
424465
return key
@@ -911,33 +952,6 @@ def transpose(self, *dims):
911952
variable = self.variable.transpose(*dims)
912953
return self._replace(variable)
913954

914-
def squeeze(self, dim=None):
915-
"""Return a new DataArray object with squeezed data.
916-
917-
Parameters
918-
----------
919-
dim : None or str or tuple of str, optional
920-
Selects a subset of the length one dimensions. If a dimension is
921-
selected with length greater than one, an error is raised. If
922-
None, all length one dimensions are squeezed.
923-
924-
Returns
925-
-------
926-
squeezed : DataArray
927-
This array, but with with all or a subset of the dimensions of
928-
length 1 removed.
929-
930-
Notes
931-
-----
932-
Although this operation returns a view of this array's data, it is
933-
not lazy -- the data will be fully loaded.
934-
935-
See Also
936-
--------
937-
numpy.squeeze
938-
"""
939-
return squeeze(self, dict(zip(self.dims, self.shape)), dim)
940-
941955
def drop(self, labels, dim=None):
942956
"""Drop coordinates or index labels from this DataArray.
943957

xarray/core/dataset.py

+16-29
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,25 @@ def attrs(self, value):
291291
def dims(self):
292292
"""Mapping from dimension names to lengths.
293293
294-
This dictionary cannot be modified directly, but is updated when adding
295-
new variables.
294+
Cannot be modified directly, but is updated when adding new variables.
295+
296+
Note that type of this object differs from `DataArray.dims`.
297+
See `Dataset.sizes` and `DataArray.sizes` for consistently named
298+
properties.
296299
"""
297300
return Frozen(SortedKeysDict(self._dims))
298301

302+
@property
303+
def sizes(self):
304+
"""Mapping from dimension names to lengths.
305+
306+
Cannot be modified directly, but is updated when adding new variables.
307+
308+
This is an alias for `Dataset.dims` provided for the benefit of
309+
consistency with `DataArray.sizes`.
310+
"""
311+
return self.dims
312+
299313
def load(self):
300314
"""Manually trigger loading of this dataset's data from disk or a
301315
remote source into memory and return this dataset.
@@ -1584,33 +1598,6 @@ def transpose(self, *dims):
15841598
def T(self):
15851599
return self.transpose()
15861600

1587-
def squeeze(self, dim=None):
1588-
"""Returns a new dataset with squeezed data.
1589-
1590-
Parameters
1591-
----------
1592-
dim : None or str or tuple of str, optional
1593-
Selects a subset of the length one dimensions. If a dimension is
1594-
selected with length greater than one, an error is raised. If
1595-
None, all length one dimensions are squeezed.
1596-
1597-
Returns
1598-
-------
1599-
squeezed : Dataset
1600-
This dataset, but with with all or a subset of the dimensions of
1601-
length 1 removed.
1602-
1603-
Notes
1604-
-----
1605-
Although this operation returns a view of each variable's data, it is
1606-
not lazy -- all variable data will be fully loaded.
1607-
1608-
See Also
1609-
--------
1610-
numpy.squeeze
1611-
"""
1612-
return common.squeeze(self, self.dims, dim)
1613-
16141601
def dropna(self, dim, how='any', thresh=None, subset=None):
16151602
"""Returns a new dataset with dropped labels for missing values along
16161603
the provided dimension.

xarray/test/test_dataarray.py

+16
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,22 @@ def test_dims(self):
159159
with self.assertRaisesRegexp(AttributeError, 'you cannot assign'):
160160
arr.dims = ('w', 'z')
161161

162+
def test_sizes(self):
163+
array = DataArray(np.zeros((3, 4)), dims=['x', 'y'])
164+
self.assertEqual(array.sizes, {'x': 3, 'y': 4})
165+
166+
with self.assertRaisesRegexp(KeyError, repr('foo')):
167+
array.sizes['foo']
168+
self.assertEqual(array.sizes['x'], 3)
169+
self.assertEqual(array.sizes['y'], 4)
170+
171+
self.assertIn('x', array.sizes)
172+
self.assertNotIn('foo', array.sizes)
173+
174+
self.assertEqual(tuple(array.sizes), array.dims)
175+
self.assertEqual(len(array.sizes), 2)
176+
self.assertEqual(repr(array.sizes), u'<DataArraySizes (x: 3, y: 4)>')
177+
162178
def test_encoding(self):
163179
expected = {'foo': 'bar'}
164180
self.dv.encoding['foo'] = 'bar'

xarray/test/test_dataset.py

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def test_properties(self):
338338
self.assertEqual(ds.dims,
339339
{'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20})
340340
self.assertEqual(list(ds.dims), sorted(ds.dims))
341+
self.assertEqual(ds.sizes, ds.dims)
341342

342343
# These exact types aren't public API, but this makes sure we don't
343344
# change them inadvertently:

0 commit comments

Comments
 (0)