Skip to content

Commit 293922e

Browse files
jsignellshoyer
authored andcommitted
catch numpy arrays in attrs before converting to dict (#1052)
* catch np.arrays in attrs before converting to dict * pushed time_check and np_check out to utils.py * changed utils funtion names * added numpy scalar test
1 parent f0c7203 commit 293922e

File tree

5 files changed

+88
-35
lines changed

5 files changed

+88
-35
lines changed

xarray/core/dataarray.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
default_index_coordinate,
2323
assert_unique_multiindex_level_names)
2424
from .formatting import format_item
25+
from .utils import decode_numpy_dict_values, ensure_us_time_resolution
2526

2627

2728
def _infer_coords_and_dims(shape, coords, dims):
@@ -1194,29 +1195,25 @@ def to_dict(self):
11941195
Convert this xarray.DataArray into a dictionary following xarray
11951196
naming conventions.
11961197
1197-
Useful for coverting to json.
1198+
Converts all variables and attributes to native Python objects.
1199+
Useful for coverting to json. To avoid datetime incompatibility
1200+
use decode_times=False kwarg in xarrray.open_dataset.
11981201
11991202
See also
12001203
--------
12011204
xarray.DataArray.from_dict
12021205
"""
1203-
d = {'coords': {}, 'attrs': dict(self.attrs), 'dims': self.dims}
1204-
1205-
def time_check(val):
1206-
# needed because of numpy bug GH#7619
1207-
if np.issubdtype(val.dtype, np.datetime64):
1208-
val = val.astype('datetime64[us]')
1209-
elif np.issubdtype(val.dtype, np.timedelta64):
1210-
val = val.astype('timedelta64[us]')
1211-
return val
1206+
d = {'coords': {}, 'attrs': decode_numpy_dict_values(self.attrs),
1207+
'dims': self.dims}
12121208

12131209
for k in self.coords:
1214-
data = time_check(self[k].values).tolist()
1215-
d['coords'].update({k: {'data': data,
1216-
'dims': self[k].dims,
1217-
'attrs': dict(self[k].attrs)}})
1210+
data = ensure_us_time_resolution(self[k].values).tolist()
1211+
d['coords'].update({
1212+
k: {'data': data,
1213+
'dims': self[k].dims,
1214+
'attrs': decode_numpy_dict_values(self[k].attrs)}})
12181215

1219-
d.update({'data': time_check(self.values).tolist(),
1216+
d.update({'data': ensure_us_time_resolution(self.values).tolist(),
12201217
'name': self.name})
12211218
return d
12221219

xarray/core/dataset.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from .common import ImplementsDatasetReduce, BaseDataObject
2020
from .merge import (dataset_update_method, dataset_merge_method,
2121
merge_data_and_coords)
22-
from .utils import Frozen, SortedKeysDict, maybe_wrap_array, hashable
22+
from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable,
23+
decode_numpy_dict_values, ensure_us_time_resolution)
2324
from .variable import (Variable, as_variable, IndexVariable, broadcast_variables)
2425
from .pycompat import (iteritems, basestring, OrderedDict,
2526
dask_array_type)
@@ -1917,33 +1918,29 @@ def to_dict(self):
19171918
Convert this dataset to a dictionary following xarray naming
19181919
conventions.
19191920
1920-
Useful for coverting to json.
1921+
Converts all variables and attributes to native Python objects
1922+
Useful for coverting to json. To avoid datetime incompatibility
1923+
use decode_times=False kwarg in xarrray.open_dataset.
19211924
19221925
See also
19231926
--------
19241927
xarray.Dataset.from_dict
19251928
"""
1926-
d = {'coords': {}, 'attrs': dict(self.attrs), 'dims': dict(self.dims),
1927-
'data_vars': {}}
1928-
1929-
def time_check(val):
1930-
# needed because of numpy bug GH#7619
1931-
if np.issubdtype(val.dtype, np.datetime64):
1932-
val = val.astype('datetime64[us]')
1933-
elif np.issubdtype(val.dtype, np.timedelta64):
1934-
val = val.astype('timedelta64[us]')
1935-
return val
1929+
d = {'coords': {}, 'attrs': decode_numpy_dict_values(self.attrs),
1930+
'dims': dict(self.dims), 'data_vars': {}}
19361931

19371932
for k in self.coords:
1938-
data = time_check(self[k].values).tolist()
1939-
d['coords'].update({k: {'data': data,
1940-
'dims': self[k].dims,
1941-
'attrs': dict(self[k].attrs)}})
1933+
data = ensure_us_time_resolution(self[k].values).tolist()
1934+
d['coords'].update({
1935+
k: {'data': data,
1936+
'dims': self[k].dims,
1937+
'attrs': decode_numpy_dict_values(self[k].attrs)}})
19421938
for k in self.data_vars:
1943-
data = time_check(self[k].values).tolist()
1944-
d['data_vars'].update({k: {'data': data,
1945-
'dims': self[k].dims,
1946-
'attrs': dict(self[k].attrs)}})
1939+
data = ensure_us_time_resolution(self[k].values).tolist()
1940+
d['data_vars'].update({
1941+
k: {'data': data,
1942+
'dims': self[k].dims,
1943+
'attrs': decode_numpy_dict_values(self[k].attrs)}})
19471944
return d
19481945

19491946
@classmethod

xarray/core/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,25 @@ def hashable(v):
460460

461461
def not_implemented(*args, **kwargs):
462462
return NotImplemented
463+
464+
465+
def decode_numpy_dict_values(attrs):
466+
"""Convert attribute values from numpy objects to native Python objects,
467+
for use in to_dict"""
468+
attrs = dict(attrs)
469+
for k, v in attrs.items():
470+
if isinstance(v, np.ndarray):
471+
attrs[k] = v.tolist()
472+
elif isinstance(v, np.generic):
473+
attrs[k] = np.asscalar(v)
474+
return attrs
475+
476+
477+
def ensure_us_time_resolution(val):
478+
"""Convert val out of numpy time, for use in to_dict.
479+
Needed because of numpy bug GH#7619"""
480+
if np.issubdtype(val.dtype, np.datetime64):
481+
val = val.astype('datetime64[us]')
482+
elif np.issubdtype(val.dtype, np.timedelta64):
483+
val = val.astype('timedelta64[us]')
484+
return val

xarray/test/test_dataarray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,24 @@ def test_to_and_from_dict_with_nan_nat(self):
19891989
roundtripped = DataArray.from_dict(da.to_dict())
19901990
self.assertDataArrayIdentical(da, roundtripped)
19911991

1992+
def test_to_dict_with_numpy_attrs(self):
1993+
# this doesn't need to roundtrip
1994+
x = np.random.randn(10, 3)
1995+
t = list('abcdefghij')
1996+
lat = [77.7, 83.2, 76]
1997+
attrs = {'created': np.float64(1998),
1998+
'coords': np.array([37, -110.1, 100]),
1999+
'maintainer': 'bar'}
2000+
da = DataArray(x, {'t': t, 'lat': lat}, dims=['t', 'lat'],
2001+
attrs=attrs)
2002+
expected_attrs = {'created': np.asscalar(attrs['created']),
2003+
'coords': attrs['coords'].tolist(),
2004+
'maintainer': 'bar'}
2005+
actual = da.to_dict()
2006+
2007+
# check that they are identical
2008+
self.assertEqual(expected_attrs, actual['attrs'])
2009+
19922010
def test_to_masked_array(self):
19932011
rs = np.random.RandomState(44)
19942012
x = rs.random_sample(size=(10, 20))

xarray/test/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,25 @@ def test_to_and_from_dict_with_nan_nat(self):
21202120
roundtripped = Dataset.from_dict(ds.to_dict())
21212121
self.assertDatasetIdentical(ds, roundtripped)
21222122

2123+
def test_to_dict_with_numpy_attrs(self):
2124+
# this doesn't need to roundtrip
2125+
x = np.random.randn(10)
2126+
y = np.random.randn(10)
2127+
t = list('abcdefghij')
2128+
attrs = {'created': np.float64(1998),
2129+
'coords': np.array([37, -110.1, 100]),
2130+
'maintainer': 'bar'}
2131+
ds = Dataset(OrderedDict([('a', ('t', x, attrs)),
2132+
('b', ('t', y, attrs)),
2133+
('t', ('t', t))]))
2134+
expected_attrs = {'created': np.asscalar(attrs['created']),
2135+
'coords': attrs['coords'].tolist(),
2136+
'maintainer': 'bar'}
2137+
actual = ds.to_dict()
2138+
2139+
# check that they are identical
2140+
self.assertEqual(expected_attrs, actual['data_vars']['a']['attrs'])
2141+
21232142
def test_pickle(self):
21242143
data = create_test_data()
21252144
roundtripped = pickle.loads(pickle.dumps(data))

0 commit comments

Comments
 (0)