Adding arbitrary object serialization #1421

Status: Open. Wants to merge 1 commit into main.

10 changes: 6 additions & 4 deletions xarray/backends/api.py
@@ -288,7 +288,8 @@ def maybe_decode_store(store, lock=False):
allow_remote=True)
if engine == 'netcdf4':
store = backends.NetCDF4DataStore(filename_or_obj, group=group,
autoclose=autoclose)
autoclose=autoclose,
allow_object=True)
elif engine == 'scipy':
store = backends.ScipyDataStore(filename_or_obj,
autoclose=autoclose)
@@ -534,7 +535,8 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,


def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
engine=None, writer=None, encoding=None, unlimited_dims=None):
engine=None, writer=None, encoding=None, unlimited_dims=None,
allow_object=False):
"""This function creates an appropriate datastore for writing a dataset to
disk as a netCDF file

Expand Down Expand Up @@ -574,8 +576,8 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
sync = writer is None

target = path_or_file if path_or_file is not None else BytesIO()
store = store_cls(target, mode, format, group, writer)

store = store_cls(target, mode, format, group, writer,
allow_object=allow_object)
if unlimited_dims is None:
unlimited_dims = dataset.encoding.get('unlimited_dims', None)
try:
19 changes: 17 additions & 2 deletions xarray/backends/common.py
@@ -184,11 +184,25 @@ def sync(self):


class AbstractWritableDataStore(AbstractDataStore):
def __init__(self, writer=None):
def __init__(self, writer=None, allow_object=False):
self.allow_object = allow_object
if writer is None:
writer = ArrayWriter()
self.writer = writer

@property
def allow_object(self):
return self._allow_object

@allow_object.setter
def allow_object(self, value):
if value:
msg = "'{}' does not support native Python object " \
"serialization'".format(self.__class__.__name__)
raise NotImplemented(msg)
else:
self._allow_object = value

def set_dimension(self, d, l): # pragma: no cover
raise NotImplementedError

@@ -241,7 +255,8 @@ class WritableCFDataStore(AbstractWritableDataStore):
def store(self, variables, attributes, *args, **kwargs):
# All NetCDF files get CF encoded by default, without this attempting
# to write times, for example, would fail.
cf_variables, cf_attrs = cf_encoder(variables, attributes)
cf_variables, cf_attrs = cf_encoder(variables, attributes,
allow_object=self.allow_object)
AbstractWritableDataStore.store(self, cf_variables, cf_attrs,
*args, **kwargs)

4 changes: 2 additions & 2 deletions xarray/backends/h5netcdf_.py
@@ -55,7 +55,7 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
"""Store for reading and writing data via h5netcdf
"""
def __init__(self, filename, mode='r', format=None, group=None,
writer=None, autoclose=False):
writer=None, autoclose=False, allow_object=False):
if format not in [None, 'NETCDF4']:
raise ValueError('invalid format for h5netcdf backend')
opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
@@ -71,7 +71,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
self._opener = opener
self._filename = filename
self._mode = mode
super(H5NetCDFStore, self).__init__(writer)
super(H5NetCDFStore, self).__init__(writer, allow_object=allow_object)

def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
6 changes: 4 additions & 2 deletions xarray/backends/memory.py
@@ -18,10 +18,12 @@ class InMemoryDataStore(AbstractWritableDataStore):

This store exists purely for internal testing purposes.
"""
def __init__(self, variables=None, attributes=None, writer=None):
def __init__(self, variables=None, attributes=None, writer=None,
allow_object=False):
self._variables = OrderedDict() if variables is None else variables
self._attributes = OrderedDict() if attributes is None else attributes
super(InMemoryDataStore, self).__init__(writer)
super(InMemoryDataStore, self).__init__(writer,
allow_object=allow_object)

def get_attrs(self):
return self._attributes
53 changes: 44 additions & 9 deletions xarray/backends/netCDF4_.py
@@ -11,11 +11,11 @@
from ..core import indexing
from ..core.utils import (FrozenOrderedDict, NdimSizeLenMixin,
DunderArrayMixin, close_on_error,
is_remote_uri)
is_remote_uri, encode_pickle, decode_pickle)
from ..core.pycompat import iteritems, basestring, OrderedDict, PY3

from .common import (WritableCFDataStore, robust_getitem,
DataStorePickleMixin, find_root)
DataStorePickleMixin, find_root, ArrayWriter)
from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable,
maybe_convert_to_char_array)

@@ -41,6 +41,11 @@ def __init__(self, variable_name, datastore):
# represent variable length strings; it also prevents automatic
# string concatenation via conventions.decode_cf_variable
dtype = np.dtype('O')
if dtype == np.uint8 and array.datatype.name == 'object':
self.is_object = True
dtype = np.dtype('O')
else:
self.is_object = False
self.dtype = dtype

def get_array(self):
@@ -75,6 +80,8 @@ def __getitem__(self, key):
# arrays (slicing them always returns a 1-dimensional array):
# https://github.com/Unidata/netcdf4-python/pull/220
data = np.asscalar(data)
if self.is_object:
data = decode_pickle(data)
return data


@@ -200,9 +207,14 @@ class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin):
"""
def __init__(self, filename, mode='r', format='NETCDF4', group=None,
writer=None, clobber=True, diskless=False, persist=False,
autoclose=False):
autoclose=False, allow_object=False):
if format is None:
format = 'NETCDF4'

assert not allow_object or format.startswith('NETCDF4'), \
"""serializing native Python objects is only possible with
'NETCDF4' format. Current format is '{}'""".format(format)

opener = functools.partial(_open_netcdf4_group, filename, mode=mode,
group=group, clobber=clobber,
diskless=diskless, persist=persist,
@@ -215,7 +227,22 @@ def __init__(self, filename, mode='r', format='NETCDF4', group=None,
self._filename = filename
self._mode = 'a' if mode == 'w' else mode
self._opener = functools.partial(opener, mode=self._mode)
super(NetCDF4DataStore, self).__init__(writer)
super(NetCDF4DataStore, self).__init__(writer,
allow_object=allow_object)

@WritableCFDataStore.allow_object.setter
def allow_object(self, value):
self._allow_object = value

@property
def _object_datatype(self):
msg = "Object datatype only supported with keyword 'allow_object'"
assert self.allow_object, msg
dtype = self.ds.vltypes.get('object',
self.ds.createVLType(np.uint8, 'object'))
msg = "Object datatype is '{}'. Should be 'uint8'.".format(dtype.dtype)
assert dtype.dtype == np.uint8, msg
return dtype

def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
@@ -287,17 +314,25 @@ def set_variables(self, *args, **kwargs):
with self.ensure_open(autoclose=False):
super(NetCDF4DataStore, self).set_variables(*args, **kwargs)

def _extract_variable_and_datatype(self, variable):
if self.format == 'NETCDF4':
if variable.dtype.kind == 'O' and self.allow_object:
datatype = self._object_datatype
variable.data = encode_pickle(variable)
else:
variable, datatype = _nc4_values_and_dtype(variable)
else:
variable = encode_nc3_variable(variable)
datatype = variable.dtype
return variable, datatype

def prepare_variable(self, name, variable, check_encoding=False,
unlimited_dims=None):
attrs = variable.attrs.copy()

variable = _force_native_endianness(variable)

if self.format == 'NETCDF4':
variable, datatype = _nc4_values_and_dtype(variable)
else:
variable = encode_nc3_variable(variable)
datatype = variable.dtype
variable, datatype = self._extract_variable_and_datatype(variable)

self.set_necessary_dimensions(variable, unlimited_dims=unlimited_dims)

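For context, the VLEN-of-uint8 scheme that the _object_datatype property above relies on can be reproduced with plain netCDF4-python. A minimal standalone sketch (file and variable names are illustrative, not part of the diff):

import pickle

import netCDF4
import numpy as np

# A variable-length (VLEN) type of uint8 lets each cell hold an
# arbitrary-length byte string, i.e. one pickled Python object per element.
with netCDF4.Dataset('objects.nc', 'w', format='NETCDF4') as nc:
    nc.createDimension('x', 2)
    vltype = nc.createVLType(np.uint8, 'object')
    var = nc.createVariable('payload', vltype, ('x',))

    buffers = np.empty(2, dtype=object)
    for i, obj in enumerate([{'a': 1}, [1, 2, 3]]):
        buffers[i] = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)
    var[:] = buffers  # each element is stored as a ragged uint8 array

Reading var[:] back yields an object array of uint8 buffers, which is the shape of data the decode_pickle helper below expects.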
4 changes: 2 additions & 2 deletions xarray/backends/scipy_.py
@@ -90,7 +90,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin):
It only supports the NetCDF3 file-format.
"""
def __init__(self, filename_or_obj, mode='r', format=None, group=None,
writer=None, mmap=None, autoclose=False):
writer=None, mmap=None, autoclose=False, allow_object=False):
import scipy
import scipy.io

@@ -122,7 +122,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None,
self._opener = opener
self._mode = mode

super(ScipyDataStore, self).__init__(writer)
super(ScipyDataStore, self).__init__(writer, allow_object=allow_object)

def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
32 changes: 22 additions & 10 deletions xarray/conventions.py
@@ -663,7 +663,7 @@ def maybe_encode_bools(var):
return var


def _infer_dtype(array):
def _infer_dtype(array, allow_object=False):
"""Given an object array with no missing values, infer its dtype from its
first element
"""
@@ -676,20 +676,24 @@ def _infer_dtype(array, allow_object=False):
# the length of their first element
dtype = np.dtype(dtype.kind)
elif dtype.kind == 'O':
raise ValueError('unable to infer dtype; xarray cannot '
'serialize arbitrary Python objects')
if allow_object:
dtype = np.dtype(object)
else:
raise ValueError('unable to infer dtype; xarray cannot '
'serialize arbitrary Python objects')
return dtype


def ensure_dtype_not_object(var):
def maybe_infer_dtype(var, allow_object=False):
# TODO: move this from conventions to backends? (it's not CF related)
if var.dtype.kind == 'O':
dims, data, attrs, encoding = _var_as_tuple(var)
missing = pd.isnull(data)
if missing.any():
# nb. this will fail for dask.array data
non_missing_values = data[~missing]
inferred_dtype = _infer_dtype(non_missing_values)
inferred_dtype = _infer_dtype(non_missing_values,
allow_object=allow_object)

Member: If _infer_dtype fails, we want to break out of this function and return the original var, not copy the data and put it on a new Variable object (which happens below).
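A minimal sketch (not part of the diff) of the early return the reviewer is describing, assuming _infer_dtype keeps raising ValueError when it cannot infer a dtype:

try:
    inferred_dtype = _infer_dtype(non_missing_values,
                                  allow_object=allow_object)
except ValueError:
    # Could not infer a serializable dtype: return the original
    # variable untouched instead of copying its data below.
    return var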

if inferred_dtype.kind in ['S', 'U']:
# There is no safe bit-pattern for NA in typical binary string
@@ -706,12 +710,13 @@ def ensure_dtype_not_object(var):
data = np.array(data, dtype=inferred_dtype, copy=True)
data[missing] = fill_value
else:
data = data.astype(dtype=_infer_dtype(data))
data = data.astype(dtype=_infer_dtype(data,
allow_object=allow_object))
var = Variable(dims, data, attrs, encoding)
return var


def encode_cf_variable(var, needs_copy=True, name=None):
def encode_cf_variable(var, needs_copy=True, name=None, allow_object=False):
"""
Converts an Variable into an Variable which follows some
of the CF conventions:
@@ -725,6 +730,9 @@ def encode_cf_variable(var, needs_copy=True, name=None, allow_object=False):
----------
var : xarray.Variable
A variable holding un-encoded data.
allow_object : bool
Whether to allow native Python objects to pass through the encoder;
if False, attempting to encode an object array raises an error.

Returns
-------
@@ -738,7 +746,7 @@ def encode_cf_variable(var, needs_copy=True, name=None):
var = maybe_encode_dtype(var, name)
var = maybe_default_fill_value(var)
var = maybe_encode_bools(var)
var = ensure_dtype_not_object(var)
var = maybe_infer_dtype(var, allow_object=allow_object)
return var


@@ -1058,7 +1066,7 @@ def encode_dataset_coordinates(dataset):
non_dim_coord_names=non_dim_coord_names)


def cf_encoder(variables, attributes):
def cf_encoder(variables, attributes, allow_object=False):
"""
A function which takes a dicts of variables and attributes
and encodes them to conform to CF conventions as much
@@ -1075,6 +1083,9 @@ def cf_encoder(variables, attributes, allow_object=False):
A dictionary mapping from variable name to xarray.Variable
attributes : dict
A dictionary mapping from attribute name to value
allow_object : bool
Whether to allow native Python objects to pass through the encoder;
if False, attempting to encode an object array raises an error.

Returns
-------
Expand All @@ -1085,6 +1096,7 @@ def cf_encoder(variables, attributes):

See also: encode_cf_variable
"""
new_vars = OrderedDict((k, encode_cf_variable(v, name=k))
new_vars = OrderedDict((k, encode_cf_variable(v, name=k,
allow_object=allow_object))
for k, v in iteritems(variables))
return new_vars, attributes
2 changes: 2 additions & 0 deletions xarray/core/dataarray.py
@@ -1314,6 +1314,8 @@ def to_netcdf(self, *args, **kwargs):
Nested dictionary with variable names as keys and dictionaries of
variable specific encodings as values, e.g.,
``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1, 'zlib': True}, ...}``
allow_object : bool, optional
If True, allow native Python objects to be serialized.

Notes
-----
9 changes: 7 additions & 2 deletions xarray/core/dataset.py
@@ -898,6 +898,7 @@ def reset_coords(self, names=None, drop=False, inplace=False):
def dump_to_store(self, store, encoder=None, sync=True, encoding=None,
unlimited_dims=None):
"""Store dataset contents to a backends.*DataStore object."""

if encoding is None:
encoding = {}
variables, attrs = conventions.encode_dataset_coordinates(self)
@@ -918,7 +919,8 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
store.sync()

def to_netcdf(self, path=None, mode='w', format=None, group=None,
engine=None, encoding=None, unlimited_dims=None):
engine=None, encoding=None, unlimited_dims=None,
allow_object=False):
"""Write dataset contents to a netCDF file.

Parameters
@@ -968,13 +970,16 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
By default, no dimensions are treated as unlimited dimensions.
Note that unlimited_dims may also be set via
``dataset.encoding['unlimited_dims']``.
allow_object : bool, optional
If True, allow native Python objects to be serialized.
"""
if encoding is None:
encoding = {}
from ..backends.api import to_netcdf
return to_netcdf(self, path, mode, format=format, group=group,
engine=engine, encoding=encoding,
unlimited_dims=unlimited_dims)
unlimited_dims=unlimited_dims,
allow_object=allow_object)

def __unicode__(self):
return formatting.dataset_repr(self)
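With these changes, user code could opt in like this (a sketch of the intended API; the file name and data are made up):

import numpy as np
import xarray as xr

payload = np.empty(2, dtype=object)
payload[0] = {'a': 1}
payload[1] = [1, 2, 3]
ds = xr.Dataset({'payload': ('x', payload)})

# Without allow_object=True, encoding raises "unable to infer dtype;
# xarray cannot serialize arbitrary Python objects".
ds.to_netcdf('objects.nc', engine='netcdf4', allow_object=True)

# Reading back goes through NetCDF4DataStore with allow_object=True,
# so the pickled payloads are restored to Python objects.
roundtripped = xr.open_dataset('objects.nc', engine='netcdf4')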
11 changes: 11 additions & 0 deletions xarray/core/utils.py
@@ -9,6 +9,7 @@
import re
import warnings
from collections import Mapping, MutableMapping, Iterable
from six.moves import cPickle as pickle
Member: xarray doesn't depend on six, so you need to use a try/except here:

try:
    import cPickle as pickle
except ImportError:
    import pickle

Contributor Author: Wow sorry, I actually forgot it wasn't a builtin, and I lazily didn't set up a proper dev environment. Definitely a simple fix.

Member: This try/except statement should probably go in xarray's pycompat module.


import numpy as np
import pandas as pd
@@ -489,3 +490,13 @@ def ensure_us_time_resolution(val):
elif np.issubdtype(val.dtype, np.timedelta64):
val = val.astype('timedelta64[us]')
return val


@functools.partial(np.vectorize, otypes='O')
def encode_pickle(obj):
return np.frombuffer(pickle.dumps(obj), dtype=np.uint8)
Member: We probably want to use a later version of the pickle format -- at least version 2 (which introduced the binary version) if not pickle.HIGHEST_PROTOCOL. Possibly this should be a user-controllable argument.

For reference, numpy.save uses protocol=2 and pandas.DataFrame.to_pickle uses HIGHEST_PROTOCOL (which is protocol=2 on Python 2, and currently protocol=4 on Python 3).

Contributor Author: I think six handles the protocol issue, which is why I didn't do anything here, but with no six we can handle that manually. I don't know much about pickle 2 and 3 compatibility (i.e. dump in 2, load in 3); perhaps that would be the nicest configuration to default to?

Member: The way that pickle works, any version of Python can load older pickles, but old versions of Python can't load newer pickles. So protocol=2 is a maximally backwards-compatible option, but misses out on any later pickle improvements.

Member: Let's set HIGHEST_PROTOCOL in pycompat as well.
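Taken together, the two suggestions would amount to something roughly like this in pycompat (a sketch, not code from this PR):

# hypothetical additions to xarray/core/pycompat.py
try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle  # Python 3

# Protocol used when pickling objects; HIGHEST_PROTOCOL favours newer
# features, while protocol=2 would favour cross-version compatibility.
HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL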



@functools.partial(np.vectorize, otypes='O')
def decode_pickle(obj):
return pickle.loads(obj.tostring())
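A quick round-trip check of the two helpers above (illustrative only; the sample objects are made up):

import numpy as np

objs = np.empty(2, dtype=object)
objs[0] = {'a': 1}
objs[1] = [1, 2, 3]

encoded = encode_pickle(objs)     # object array whose elements are uint8 buffers
decoded = decode_pickle(encoded)  # back to the original Python objects
assert decoded[0] == {'a': 1} and decoded[1] == [1, 2, 3]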