Skip to content

Commit e877aa7

Browse files
committed
convert to float32 to keep pydata#1840 in sync
1 parent 19ef234 commit e877aa7

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

xarray/coding/variables.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ def _choose_float_dtype(
302302
) -> type[np.floating[Any]]:
303303
# check scale/offset first to derive dtype
304304
# see https://github.com/pydata/xarray/issues/5597#issuecomment-879561954
305-
scale_factor = mapping.get("scale_factor", False)
306-
add_offset = mapping.get("add_offset", False)
305+
scale_factor = mapping.get("scale_factor")
306+
add_offset = mapping.get("add_offset")
307307
if scale_factor or add_offset:
308308
# get the maximum itemsize from scale_factor/add_offset to determine
309309
# the needed floating point type
@@ -320,7 +320,7 @@ def _choose_float_dtype(
320320
# but a large integer offset could lead to loss of precision.
321321
# Sensitivity analysis can be tricky, so we just use a float64
322322
# if there's any offset at all - better unoptimised than wrong!
323-
if maxsize == 4 and np.issubdtype(add_offset_type, np.floating):
323+
if maxsize == 4 or not np.issubdtype(add_offset_type, np.floating):
324324
return np.float32
325325
else:
326326
return np.float64
@@ -350,12 +350,14 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
350350
if scale_factor or add_offset:
351351
dtype = _choose_float_dtype(data.dtype, attrs)
352352
data = data.astype(dtype=dtype, copy=True)
353-
if add_offset:
354-
data -= add_offset
355-
if scale_factor:
356-
data /= scale_factor
353+
if add_offset:
354+
data -= add_offset
355+
if scale_factor:
356+
data /= scale_factor
357357

358-
return Variable(dims, data, attrs, encoding, fastpath=True)
358+
return Variable(dims, data, attrs, encoding, fastpath=True)
359+
else:
360+
return variable
359361

360362
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
361363
dims, data, attrs, encoding = unpack_for_decoding(variable)

xarray/tests/test_coding.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,11 @@ def test_coder_roundtrip() -> None:
9595
assert_identical(original, roundtripped)
9696

9797

98-
@pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split())
99-
def test_scaling_converts_to_float32(dtype) -> None:
98+
@pytest.mark.parametrize("unpacked_dtype", [np.float32, np.float64, np.int32])
99+
@pytest.mark.parametrize("packed_dtype", "u1 u2 i1 i2 f2 f4".split())
100+
def test_scaling_converts_to_float32(packed_dtype, unpacked_dtype) -> None:
100101
original = xr.Variable(
101-
("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=10)
102+
("x",), np.arange(10, dtype=packed_dtype), encoding=dict(scale_factor=unpacked_dtype(10))
102103
)
103104
coder = variables.CFScaleOffsetCoder()
104105
encoded = coder.encode(original)

0 commit comments

Comments
 (0)