Skip to content

Commit 4106b94

Browse files
authored
Fix dtype=S1 encoding in to_netcdf() (#2158)
* Fix dtype=S1 encoding in to_netcdf() Fixes GH2149 * Add test_encoding_kwarg_compression from crusaderky * Fix dtype=S1 in kwargs for bytes, too * Fix lint * Move compression encoding kwarg test * Remove no longer relevant changes * Fix encoding dtype=str * More lint * Fix failed tests * Review comments * oops, we still need to skip that test * check for presence in a tuple rather than making two comparisons
1 parent 9d60897 commit 4106b94

File tree

7 files changed

+114
-18
lines changed

7 files changed

+114
-18
lines changed

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ Enhancements
7171
Bug fixes
7272
~~~~~~~~~
7373

74+
- Fixed a regression in 0.10.4, where explicitly specifying ``dtype='S1'`` or
75+
``dtype=str`` in ``encoding`` with ``to_netcdf()`` raised an error
76+
(:issue:`2149`).
77+
By `Stephan Hoyer <https://github.com/shoyer>`_.
78+
7479
- :py:func:`apply_ufunc` now directly validates output variables
7580
(:issue:`1931`).
7681
By `Stephan Hoyer <https://github.com/shoyer>`_.

xarray/backends/h5netcdf_.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def __init__(self, filename, mode='r', format=None, group=None,
9494
super(H5NetCDFStore, self).__init__(writer, lock=lock)
9595

9696
def open_store_variable(self, name, var):
97+
import h5py
98+
9799
with self.ensure_open(autoclose=False):
98100
dimensions = var.dimensions
99101
data = indexing.LazilyOuterIndexedArray(
@@ -119,6 +121,15 @@ def open_store_variable(self, name, var):
119121
encoding['source'] = self._filename
120122
encoding['original_shape'] = var.shape
121123

124+
vlen_dtype = h5py.check_dtype(vlen=var.dtype)
125+
if vlen_dtype is unicode_type:
126+
encoding['dtype'] = str
127+
elif vlen_dtype is not None: # pragma: no cover
128+
# xarray doesn't support writing arbitrary vlen dtypes yet.
129+
pass
130+
else:
131+
encoding['dtype'] = var.dtype
132+
122133
return Variable(dimensions, data, attrs, encoding)
123134

124135
def get_variables(self):
@@ -161,7 +172,8 @@ def prepare_variable(self, name, variable, check_encoding=False,
161172
import h5py
162173

163174
attrs = variable.attrs.copy()
164-
dtype = _get_datatype(variable)
175+
dtype = _get_datatype(
176+
variable, raise_on_invalid_encoding=check_encoding)
165177

166178
fillvalue = attrs.pop('_FillValue', None)
167179
if dtype is str and fillvalue is not None:
@@ -189,8 +201,9 @@ def prepare_variable(self, name, variable, check_encoding=False,
189201
raise ValueError("'zlib' and 'compression' encodings mismatch")
190202
encoding.setdefault('compression', 'gzip')
191203

192-
if (check_encoding and encoding.get('complevel') not in
193-
(None, encoding.get('compression_opts'))):
204+
if (check_encoding and
205+
'complevel' in encoding and 'compression_opts' in encoding and
206+
encoding['complevel'] != encoding['compression_opts']):
194207
raise ValueError("'complevel' and 'compression_opts' encodings "
195208
"mismatch")
196209
complevel = encoding.pop('complevel', 0)

xarray/backends/netCDF4_.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,33 @@ def _encode_nc4_variable(var):
8989
return var
9090

9191

92-
def _get_datatype(var, nc_format='NETCDF4'):
92+
def _check_encoding_dtype_is_vlen_string(dtype):
93+
if dtype is not str:
94+
raise AssertionError( # pragma: no cover
95+
"unexpected dtype encoding %r. This shouldn't happen: please "
96+
"file a bug report at github.com/pydata/xarray" % dtype)
97+
98+
99+
def _get_datatype(var, nc_format='NETCDF4', raise_on_invalid_encoding=False):
93100
if nc_format == 'NETCDF4':
94101
datatype = _nc4_dtype(var)
95102
else:
103+
if 'dtype' in var.encoding:
104+
encoded_dtype = var.encoding['dtype']
105+
_check_encoding_dtype_is_vlen_string(encoded_dtype)
106+
if raise_on_invalid_encoding:
107+
raise ValueError(
108+
'encoding dtype=str for vlen strings is only supported '
109+
'with format=\'NETCDF4\'.')
96110
datatype = var.dtype
97111
return datatype
98112

99113

100114
def _nc4_dtype(var):
101-
if coding.strings.is_unicode_dtype(var.dtype):
115+
if 'dtype' in var.encoding:
116+
dtype = var.encoding.pop('dtype')
117+
_check_encoding_dtype_is_vlen_string(dtype)
118+
elif coding.strings.is_unicode_dtype(var.dtype):
102119
dtype = str
103120
elif var.dtype.kind in ['i', 'u', 'f', 'c', 'S']:
104121
dtype = var.dtype
@@ -172,7 +189,7 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False,
172189

173190
safe_to_drop = set(['source', 'original_shape'])
174191
valid_encodings = set(['zlib', 'complevel', 'fletcher32', 'contiguous',
175-
'chunksizes', 'shuffle', '_FillValue'])
192+
'chunksizes', 'shuffle', '_FillValue', 'dtype'])
176193
if lsd_okay:
177194
valid_encodings.add('least_significant_digit')
178195
if h5py_okay:
@@ -344,6 +361,7 @@ def open_store_variable(self, name, var):
344361
# save source so __repr__ can detect if it's local or not
345362
encoding['source'] = self._filename
346363
encoding['original_shape'] = var.shape
364+
encoding['dtype'] = var.dtype
347365

348366
return Variable(dimensions, data, attributes, encoding)
349367

@@ -398,7 +416,8 @@ def encode_variable(self, variable):
398416

399417
def prepare_variable(self, name, variable, check_encoding=False,
400418
unlimited_dims=None):
401-
datatype = _get_datatype(variable, self.format)
419+
datatype = _get_datatype(variable, self.format,
420+
raise_on_invalid_encoding=check_encoding)
402421
attrs = variable.attrs.copy()
403422

404423
fill_value = attrs.pop('_FillValue', None)

xarray/coding/strings.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ def encode(self, variable, name=None):
4343
dims, data, attrs, encoding = unpack_for_encoding(variable)
4444

4545
contains_unicode = is_unicode_dtype(data.dtype)
46-
encode_as_char = 'dtype' in encoding and encoding['dtype'] == 'S1'
46+
encode_as_char = encoding.get('dtype') == 'S1'
47+
48+
if encode_as_char:
49+
del encoding['dtype'] # no longer relevant
4750

4851
if contains_unicode and (encode_as_char or not self.allows_unicode):
4952
if '_FillValue' in attrs:
@@ -100,7 +103,7 @@ def encode(self, variable, name=None):
100103
variable = ensure_fixed_length_bytes(variable)
101104

102105
dims, data, attrs, encoding = unpack_for_encoding(variable)
103-
if data.dtype.kind == 'S':
106+
if data.dtype.kind == 'S' and encoding.get('dtype') is not str:
104107
data = bytes_to_char(data)
105108
dims = dims + ('string%s' % data.shape[-1],)
106109
return Variable(dims, data, attrs, encoding)

xarray/conventions.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def _var_as_tuple(var):
7979

8080

8181
def maybe_encode_nonstring_dtype(var, name=None):
82-
if 'dtype' in var.encoding and var.encoding['dtype'] != 'S1':
82+
if ('dtype' in var.encoding and
83+
var.encoding['dtype'] not in ('S1', str)):
8384
dims, data, attrs, encoding = _var_as_tuple(var)
8485
dtype = np.dtype(encoding.pop('dtype'))
8586
if dtype != var.dtype:
@@ -307,12 +308,7 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True,
307308
data = NativeEndiannessArray(data)
308309
original_dtype = data.dtype
309310

310-
if 'dtype' in encoding:
311-
if original_dtype != encoding['dtype']:
312-
warnings.warn("CF decoding is overwriting dtype on variable {!r}"
313-
.format(name))
314-
else:
315-
encoding['dtype'] = original_dtype
311+
encoding.setdefault('dtype', original_dtype)
316312

317313
if 'dtype' in attributes and attributes['dtype'] == 'bool':
318314
del attributes['dtype']

xarray/tests/test_backends.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,26 @@ def test_encoding_kwarg(self):
753753
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
754754
pass
755755

756+
def test_encoding_kwarg_dates(self):
756757
ds = Dataset({'t': pd.date_range('2000-01-01', periods=3)})
757758
units = 'days since 1900-01-01'
758759
kwargs = dict(encoding={'t': {'units': units}})
759760
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
760761
self.assertEqual(actual.t.encoding['units'], units)
761762
assert_identical(actual, ds)
762763

764+
def test_encoding_kwarg_fixed_width_string(self):
765+
# regression test for GH2149
766+
for strings in [
767+
[b'foo', b'bar', b'baz'],
768+
[u'foo', u'bar', u'baz'],
769+
]:
770+
ds = Dataset({'x': strings})
771+
kwargs = dict(encoding={'x': {'dtype': 'S1'}})
772+
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
773+
self.assertEqual(actual['x'].encoding['dtype'], 'S1')
774+
assert_identical(actual, ds)
775+
763776
def test_default_fill_value(self):
764777
# Test default encoding for float:
765778
ds = Dataset({'x': ('y', np.arange(10.0))})
@@ -879,8 +892,8 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False):
879892
yield files
880893

881894

882-
@requires_netCDF4
883895
class BaseNetCDF4Test(CFEncodedDataTest):
896+
"""Tests for both netCDF4-python and h5netcdf."""
884897

885898
engine = 'netcdf4'
886899

@@ -942,6 +955,18 @@ def test_write_groups(self):
942955
with self.open(tmp_file, group='data/2') as actual2:
943956
assert_identical(data2, actual2)
944957

958+
def test_encoding_kwarg_vlen_string(self):
959+
for input_strings in [
960+
[b'foo', b'bar', b'baz'],
961+
[u'foo', u'bar', u'baz'],
962+
]:
963+
original = Dataset({'x': input_strings})
964+
expected = Dataset({'x': [u'foo', u'bar', u'baz']})
965+
kwargs = dict(encoding={'x': {'dtype': str}})
966+
with self.roundtrip(original, save_kwargs=kwargs) as actual:
967+
assert actual['x'].encoding['dtype'] is str
968+
assert_identical(actual, expected)
969+
945970
def test_roundtrip_string_with_fill_value_vlen(self):
946971
values = np.array([u'ab', u'cdef', np.nan], dtype=object)
947972
expected = Dataset({'x': ('t', values)})
@@ -1054,6 +1079,23 @@ def test_compression_encoding(self):
10541079
with self.roundtrip(expected) as actual:
10551080
assert_equal(expected, actual)
10561081

1082+
def test_encoding_kwarg_compression(self):
1083+
ds = Dataset({'x': np.arange(10.0)})
1084+
encoding = dict(dtype='f4', zlib=True, complevel=9, fletcher32=True,
1085+
chunksizes=(5,), shuffle=True)
1086+
kwargs = dict(encoding=dict(x=encoding))
1087+
1088+
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1089+
assert_equal(actual, ds)
1090+
self.assertEqual(actual.x.encoding['dtype'], 'f4')
1091+
self.assertEqual(actual.x.encoding['zlib'], True)
1092+
self.assertEqual(actual.x.encoding['complevel'], 9)
1093+
self.assertEqual(actual.x.encoding['fletcher32'], True)
1094+
self.assertEqual(actual.x.encoding['chunksizes'], (5,))
1095+
self.assertEqual(actual.x.encoding['shuffle'], True)
1096+
1097+
self.assertEqual(ds.x.encoding, {})
1098+
10571099
def test_encoding_chunksizes_unlimited(self):
10581100
# regression test for GH1225
10591101
ds = Dataset({'x': [1, 2, 3], 'y': ('x', [2, 3, 4])})
@@ -1117,7 +1159,7 @@ def test_already_open_dataset(self):
11171159
expected = Dataset({'x': ((), 42)})
11181160
assert_identical(expected, ds)
11191161

1120-
def test_variable_len_strings(self):
1162+
def test_read_variable_len_strings(self):
11211163
with create_tmp_file() as tmp_file:
11221164
values = np.array(['foo', 'bar', 'baz'], dtype=object)
11231165

@@ -1410,6 +1452,10 @@ def test_group(self):
14101452
open_kwargs={'group': group}) as actual:
14111453
assert_identical(original, actual)
14121454

1455+
def test_encoding_kwarg_fixed_width_string(self):
1456+
# not relevant for zarr, since we don't use EncodedStringCoder
1457+
pass
1458+
14131459
# TODO: someone who understands caching should figure out whether caching
14141460
# makes sense for Zarr backend
14151461
@pytest.mark.xfail(reason="Zarr caching not implemented")
@@ -1579,6 +1625,13 @@ def create_store(self):
15791625
tmp_file, mode='w', format='NETCDF3_CLASSIC') as store:
15801626
yield store
15811627

1628+
def test_encoding_kwarg_vlen_string(self):
1629+
original = Dataset({'x': [u'foo', u'bar', u'baz']})
1630+
kwargs = dict(encoding={'x': {'dtype': str}})
1631+
with raises_regex(ValueError, 'encoding dtype=str for vlen'):
1632+
with self.roundtrip(original, save_kwargs=kwargs):
1633+
pass
1634+
15821635

15831636
class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
15841637
autoclose = True

xarray/tests/test_conventions.py

+7
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,14 @@ def test_roundtrip_coordinates(self):
272272
'CFEncodedInMemoryStore')
273273

274274
def test_invalid_dataarray_names_raise(self):
275+
# only relevant for on-disk file formats
275276
pass
276277

277278
def test_encoding_kwarg(self):
279+
# we haven't bothered to raise errors yet for unexpected encodings in
280+
# this test dummy
281+
pass
282+
283+
def test_encoding_kwarg_fixed_width_string(self):
284+
# CFEncodedInMemoryStore doesn't support explicit string encodings.
278285
pass

0 commit comments

Comments
 (0)