diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 4929da2539c..babec1beb4b 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -288,7 +288,8 @@ def maybe_decode_store(store, lock=False):
                                            allow_remote=True)
         if engine == 'netcdf4':
             store = backends.NetCDF4DataStore(filename_or_obj, group=group,
-                                              autoclose=autoclose)
+                                              autoclose=autoclose,
+                                              allow_object=True)
         elif engine == 'scipy':
             store = backends.ScipyDataStore(filename_or_obj,
                                             autoclose=autoclose)
@@ -534,7 +535,8 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
 
 
 def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
-              engine=None, writer=None, encoding=None, unlimited_dims=None):
+              engine=None, writer=None, encoding=None, unlimited_dims=None,
+              allow_object=False):
     """This function creates an appropriate datastore for writing a dataset to
     disk as a netCDF file
@@ -574,8 +576,8 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
     sync = writer is None
 
     target = path_or_file if path_or_file is not None else BytesIO()
-    store = store_cls(target, mode, format, group, writer)
-
+    store = store_cls(target, mode, format, group, writer,
+                      allow_object=allow_object)
     if unlimited_dims is None:
         unlimited_dims = dataset.encoding.get('unlimited_dims', None)
     try:
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index cec55d22589..caf9b64add0 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -184,11 +184,25 @@ def sync(self):
 
 
 class AbstractWritableDataStore(AbstractDataStore):
-    def __init__(self, writer=None):
+    def __init__(self, writer=None, allow_object=False):
+        self.allow_object = allow_object
         if writer is None:
             writer = ArrayWriter()
         self.writer = writer
 
+    @property
+    def allow_object(self):
+        return self._allow_object
+
+    @allow_object.setter
+    def allow_object(self, value):
+        if value:
+            msg = ("'{}' does not support native Python object "
+                   "serialization".format(self.__class__.__name__))
+            raise NotImplementedError(msg)
+        else:
+            self._allow_object = value
+
     def set_dimension(self, d, l):  # pragma: no cover
         raise NotImplementedError
 
@@ -241,7 +255,8 @@ class WritableCFDataStore(AbstractWritableDataStore):
 
     def store(self, variables, attributes, *args, **kwargs):
         # All NetCDF files get CF encoded by default, without this attempting
         # to write times, for example, would fail.
-        cf_variables, cf_attrs = cf_encoder(variables, attributes)
+        cf_variables, cf_attrs = cf_encoder(variables, attributes,
+                                            allow_object=self.allow_object)
         AbstractWritableDataStore.store(self, cf_variables, cf_attrs,
                                         *args, **kwargs)
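The opt-in mechanism is the `allow_object` property above: the base-class setter rejects `True`, and a backend that can serialize objects overrides only the setter. A minimal standalone sketch of that pattern (class names here are illustrative, not part of the patch):

```python
class Base(object):
    def __init__(self, allow_object=False):
        # routed through the property setter, so unsupported
        # subclasses fail fast at construction time
        self.allow_object = allow_object

    @property
    def allow_object(self):
        return self._allow_object

    @allow_object.setter
    def allow_object(self, value):
        if value:
            raise NotImplementedError(
                "'{}' does not support native Python object "
                "serialization".format(self.__class__.__name__))
        self._allow_object = value


class ObjectCapable(Base):
    # override only the setter; the getter is inherited through the
    # parent's property object
    @Base.allow_object.setter
    def allow_object(self, value):
        self._allow_object = value


Base(allow_object=False)          # fine
ObjectCapable(allow_object=True)  # fine
# Base(allow_object=True)         # would raise NotImplementedError
```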
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index 0e7d866bc2a..d1f15134833 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -55,7 +55,7 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
     """Store for reading and writing data via h5netcdf
     """
     def __init__(self, filename, mode='r', format=None, group=None,
-                 writer=None, autoclose=False):
+                 writer=None, autoclose=False, allow_object=False):
         if format not in [None, 'NETCDF4']:
             raise ValueError('invalid format for h5netcdf backend')
         opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
@@ -71,7 +71,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
         self._opener = opener
         self._filename = filename
         self._mode = mode
-        super(H5NetCDFStore, self).__init__(writer)
+        super(H5NetCDFStore, self).__init__(writer, allow_object=allow_object)
 
     def open_store_variable(self, name, var):
         with self.ensure_open(autoclose=False):
diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py
index f79e92439fe..243ef743745 100644
--- a/xarray/backends/memory.py
+++ b/xarray/backends/memory.py
@@ -18,10 +18,12 @@ class InMemoryDataStore(AbstractWritableDataStore):
 
     This store exists purely for internal testing purposes.
     """
-    def __init__(self, variables=None, attributes=None, writer=None):
+    def __init__(self, variables=None, attributes=None, writer=None,
+                 allow_object=False):
         self._variables = OrderedDict() if variables is None else variables
         self._attributes = OrderedDict() if attributes is None else attributes
-        super(InMemoryDataStore, self).__init__(writer)
+        super(InMemoryDataStore, self).__init__(writer,
+                                                allow_object=allow_object)
 
     def get_attrs(self):
         return self._attributes
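Neither `H5NetCDFStore` nor `InMemoryDataStore` overrides the setter; they only thread the keyword through to the base class, so requesting object serialization from them fails fast at construction time. A hedged sketch of the expected behaviour (assuming the patch is applied; this is not a test from the patch):

```python
import pytest

from xarray.backends.memory import InMemoryDataStore

# the inherited base-class setter rejects allow_object=True for any
# store that has not explicitly opted in
with pytest.raises(NotImplementedError):
    InMemoryDataStore(allow_object=True)
```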
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 93af50f4ae5..4174a872585 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -11,11 +11,11 @@
 from ..core import indexing
 from ..core.utils import (FrozenOrderedDict, NdimSizeLenMixin,
                           DunderArrayMixin, close_on_error,
-                          is_remote_uri)
+                          is_remote_uri, encode_pickle, decode_pickle)
 from ..core.pycompat import iteritems, basestring, OrderedDict, PY3
 
 from .common import (WritableCFDataStore, robust_getitem,
                      DataStorePickleMixin, find_root)
 from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable,
                       maybe_convert_to_char_array)
@@ -41,6 +41,13 @@ def __init__(self, variable_name, datastore):
             # represent variable length strings; it also prevents automatic
             # string concatenation via conventions.decode_cf_variable
             dtype = np.dtype('O')
+        if dtype == np.uint8 and array.datatype.name == 'object':
+            # variables written with the vlen 'object' type hold pickled
+            # Python objects; expose them as dtype('O')
+            self.is_object = True
+            dtype = np.dtype('O')
+        else:
+            self.is_object = False
         self.dtype = dtype
 
     def get_array(self):
@@ -75,6 +82,8 @@ def __getitem__(self, key):
             # arrays (slicing them always returns a 1-dimensional array):
             # https://github.com/Unidata/netcdf4-python/pull/220
             data = np.asscalar(data)
+        if self.is_object:
+            data = decode_pickle(data)
         return data
 
@@ -200,9 +209,15 @@ class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin):
     """
     def __init__(self, filename, mode='r', format='NETCDF4', group=None,
                  writer=None, clobber=True, diskless=False, persist=False,
-                 autoclose=False):
+                 autoclose=False, allow_object=False):
         if format is None:
             format = 'NETCDF4'
+
+        if allow_object and not format.startswith('NETCDF4'):
+            raise ValueError("serializing native Python objects is only "
+                             "possible with the 'NETCDF4' format; current "
+                             "format is {!r}".format(format))
+
         opener = functools.partial(_open_netcdf4_group, filename, mode=mode,
                                    group=group, clobber=clobber,
                                    diskless=diskless, persist=persist,
@@ -215,7 +230,29 @@ def __init__(self, filename, mode='r', format='NETCDF4', group=None,
         self._filename = filename
         self._mode = 'a' if mode == 'w' else mode
         self._opener = functools.partial(opener, mode=self._mode)
-        super(NetCDF4DataStore, self).__init__(writer)
+        super(NetCDF4DataStore, self).__init__(writer,
+                                               allow_object=allow_object)
+
+    @WritableCFDataStore.allow_object.setter
+    def allow_object(self, value):
+        # the NETCDF4 format can serialize Python objects, so accept the
+        # requested value instead of raising NotImplementedError
+        self._allow_object = value
+
+    @property
+    def _object_datatype(self):
+        msg = "Object datatype only supported with keyword 'allow_object'"
+        assert self.allow_object, msg
+        # look the vlen type up before creating it: dict.get would evaluate
+        # its default eagerly and call createVLType even when the type
+        # already exists in the file
+        if 'object' in self.ds.vltypes:
+            dtype = self.ds.vltypes['object']
+        else:
+            dtype = self.ds.createVLType(np.uint8, 'object')
+        msg = "Object datatype is '{}'. Should be 'uint8'.".format(dtype.dtype)
+        assert dtype.dtype == np.uint8, msg
+        return dtype
 
     def open_store_variable(self, name, var):
         with self.ensure_open(autoclose=False):
@@ -287,17 +324,26 @@ def set_variables(self, *args, **kwargs):
         with self.ensure_open(autoclose=False):
             super(NetCDF4DataStore, self).set_variables(*args, **kwargs)
 
+    def _extract_variable_and_datatype(self, variable):
+        if self.format == 'NETCDF4':
+            if variable.dtype.kind == 'O' and self.allow_object:
+                # each object becomes a ragged uint8 array of pickled bytes
+                datatype = self._object_datatype
+                variable.data = encode_pickle(variable.data)
+            else:
+                variable, datatype = _nc4_values_and_dtype(variable)
+        else:
+            variable = encode_nc3_variable(variable)
+            datatype = variable.dtype
+        return variable, datatype
+
     def prepare_variable(self, name, variable, check_encoding=False,
                          unlimited_dims=None):
         attrs = variable.attrs.copy()
 
         variable = _force_native_endianness(variable)
 
-        if self.format == 'NETCDF4':
-            variable, datatype = _nc4_values_and_dtype(variable)
-        else:
-            variable = encode_nc3_variable(variable)
-            datatype = variable.dtype
+        variable, datatype = self._extract_variable_and_datatype(variable)
 
         self.set_necessary_dimensions(variable, unlimited_dims=unlimited_dims)
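The netCDF4 backend stores each pickled object as a ragged 1-D `uint8` cell of a variable-length type named `'object'`, which is what `_object_datatype` creates and `NetCDF4ArrayWrapper` detects on read. A minimal round trip with netCDF4-python directly, showing the on-disk scheme the store relies on (file and variable names are illustrative):

```python
import pickle

import numpy as np
import netCDF4

with netCDF4.Dataset('vlen_demo.nc', 'w', format='NETCDF4') as ds:
    obj_t = ds.createVLType(np.uint8, 'object')  # ragged uint8 cells
    ds.createDimension('x', 2)
    v = ds.createVariable('payload', obj_t, ('x',))
    for i, obj in enumerate([{'a': 1}, [1, 2, 3]]):
        v[i] = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)

with netCDF4.Dataset('vlen_demo.nc') as ds:
    cells = ds.variables['payload'][:]
    objs = [pickle.loads(c.tostring()) for c in cells]
    assert objs == [{'a': 1}, [1, 2, 3]]
```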
""" def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False): + writer=None, mmap=None, autoclose=False, allow_object=False): import scipy import scipy.io @@ -122,7 +122,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, self._opener = opener self._mode = mode - super(ScipyDataStore, self).__init__(writer) + super(ScipyDataStore, self).__init__(writer, allow_object=allow_object) def open_store_variable(self, name, var): with self.ensure_open(autoclose=False): diff --git a/xarray/conventions.py b/xarray/conventions.py index d39ae20925a..42965f8a629 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -663,7 +663,7 @@ def maybe_encode_bools(var): return var -def _infer_dtype(array): +def _infer_dtype(array, allow_object=False): """Given an object array with no missing values, infer its dtype from its first element """ @@ -676,12 +676,15 @@ def _infer_dtype(array): # the length of their first element dtype = np.dtype(dtype.kind) elif dtype.kind == 'O': - raise ValueError('unable to infer dtype; xarray cannot ' - 'serialize arbitrary Python objects') + if allow_object: + dtype = np.dtype(object) + else: + raise ValueError('unable to infer dtype; xarray cannot ' + 'serialize arbitrary Python objects') return dtype -def ensure_dtype_not_object(var): +def maybe_infer_dtype(var, allow_object=False): # TODO: move this from conventions to backends? (it's not CF related) if var.dtype.kind == 'O': dims, data, attrs, encoding = _var_as_tuple(var) @@ -689,7 +692,8 @@ def ensure_dtype_not_object(var): if missing.any(): # nb. this will fail for dask.array data non_missing_values = data[~missing] - inferred_dtype = _infer_dtype(non_missing_values) + inferred_dtype = _infer_dtype(non_missing_values, + allow_object=allow_object) if inferred_dtype.kind in ['S', 'U']: # There is no safe bit-pattern for NA in typical binary string @@ -706,12 +710,13 @@ def ensure_dtype_not_object(var): data = np.array(data, dtype=inferred_dtype, copy=True) data[missing] = fill_value else: - data = data.astype(dtype=_infer_dtype(data)) + data = data.astype(dtype=_infer_dtype(data, + allow_object=allow_object)) var = Variable(dims, data, attrs, encoding) return var -def encode_cf_variable(var, needs_copy=True, name=None): +def encode_cf_variable(var, needs_copy=True, name=None, allow_object=False): """ Converts an Variable into an Variable which follows some of the CF conventions: @@ -725,6 +730,9 @@ def encode_cf_variable(var, needs_copy=True, name=None): ---------- var : xarray.Variable A variable holding un-encoded data. + allow_object : bool + Whether to allow objects to pass the encoder or to throw an error if + this is attempted. 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c12983c20a2..c8a9fe0feaf 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1314,6 +1314,9 @@ def to_netcdf(self, *args, **kwargs):
         Nested dictionary with variable names as keys and dictionaries of
         variable specific encodings as values, e.g.,
         ``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1, 'zlib': True}, ...}``
+    allow_object : bool, optional
+        If True, allow native Python objects to be serialized. Only
+        supported by the netCDF4 backend with the 'NETCDF4' file format.
 
     Notes
     -----
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index ae5499a46a7..43fd1838bbb 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -918,7 +918,8 @@
             store.sync()
 
     def to_netcdf(self, path=None, mode='w', format=None, group=None,
-                  engine=None, encoding=None, unlimited_dims=None):
+                  engine=None, encoding=None, unlimited_dims=None,
+                  allow_object=False):
         """Write dataset contents to a netCDF file.
 
         Parameters
@@ -968,13 +969,17 @@
             By default, no dimensions are treated as unlimited dimensions.
             Note that unlimited_dims may also be set via
            ``dataset.encoding['unlimited_dims']``.
+        allow_object : bool, optional
+            If True, allow native Python objects to be serialized. Only
+            supported by the netCDF4 backend with the 'NETCDF4' file format.
         """
         if encoding is None:
             encoding = {}
 
         from ..backends.api import to_netcdf
         return to_netcdf(self, path, mode, format=format, group=group,
                          engine=engine, encoding=encoding,
-                         unlimited_dims=unlimited_dims)
+                         unlimited_dims=unlimited_dims,
+                         allow_object=allow_object)
 
     def __unicode__(self):
         return formatting.dataset_repr(self)
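End to end, only the netCDF4 backend honours the new keyword. A hedged usage sketch (file name is illustrative; requires the netcdf4 engine and this patch):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {'blobs': ('x', np.array([{'a': 1}, {'b': 2}], dtype=object))})

# without allow_object this raises ValueError during CF encoding
ds.to_netcdf('blobs.nc', engine='netcdf4', allow_object=True)

# open_dataset enables object decoding for the netcdf4 engine
roundtripped = xr.open_dataset('blobs.nc', engine='netcdf4')
print(roundtripped['blobs'].values)  # array of the original dicts
```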
""" if encoding is None: encoding = {} from ..backends.api import to_netcdf return to_netcdf(self, path, mode, format=format, group=group, engine=engine, encoding=encoding, - unlimited_dims=unlimited_dims) + unlimited_dims=unlimited_dims, + allow_object=allow_object) def __unicode__(self): return formatting.dataset_repr(self) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 89d1462328c..702f2c56bd1 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -9,6 +9,7 @@ import re import warnings from collections import Mapping, MutableMapping, Iterable +from six.moves import cPickle as pickle import numpy as np import pandas as pd @@ -489,3 +490,13 @@ def ensure_us_time_resolution(val): elif np.issubdtype(val.dtype, np.timedelta64): val = val.astype('timedelta64[us]') return val + + +@functools.partial(np.vectorize, otypes='O') +def encode_pickle(obj): + return np.frombuffer(pickle.dumps(obj), dtype=np.uint8) + + +@functools.partial(np.vectorize, otypes='O') +def decode_pickle(obj): + return pickle.loads(obj.tostring())