
Commit 2a43385

alexamici and shoyer authored
Remove the references to _file_obj outside low level code paths, change to _close (#4809)

* Move from _file_obj object to _close function
* Remove all references to _close outside of low level
* Fix type hints
* Cleanup code style
* Fix non-trivial type hint problem
* Revert adding the `close` argument and add a set_close instead
* Remove helper class for an easier helper function + code style
* Add set_close docstring
* Code style
* Revert changes in _replace to keep close as an exception
  See: https://github.com/pydata/xarray/pull/4809/files#r557628298
* One more bit to revert
* One more bit to revert
* Add What's New entry
* Use set_close setter
* Apply suggestions from code review
  Co-authored-by: Stephan Hoyer <[email protected]>
* Rename user-visible argument
* Sync wording in docstrings.

Co-authored-by: Stephan Hoyer <[email protected]>
1 parent a2b1712 commit 2a43385

File tree: 9 files changed (+53, -38 lines)

doc/whats-new.rst (+2)

@@ -108,6 +108,8 @@ Internal Changes
   By `Maximilian Roos <https://github.com/max-sixty>`_.
 - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion
   in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn <https://github.com/rhkleijn>`_.
+- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends to specify how to voluntarily
+  release all resources (:pull:`4809`). By `Alessandro Amici <https://github.com/alexamici>`_.

 .. _whats-new.0.16.2:
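The entry above is the only user-facing piece of the change: ``set_close`` registers a zero-argument callable that ``close`` later invokes. A minimal sketch of that contract, assuming an xarray version that contains this commit (the lambda closer is purely illustrative):

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3])})

# Register a closer; a real backend would pass something that releases a file handle.
ds.set_close(lambda: print("releasing backend resources"))

ds.close()  # invokes the registered callable and then clears it
```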

xarray/backends/api.py (+9, -16)

@@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks):

         else:
             ds2 = ds
-        ds2._file_obj = ds._file_obj
+        ds2.set_close(ds._close)
         return ds2

     filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,7 +701,7 @@ def open_dataarray(
     else:
         (data_array,) = dataset.data_vars.values()

-    data_array._file_obj = dataset._file_obj
+    data_array.set_close(dataset._close)

     # Reset names if they were changed during saving
     # to ensure that we can 'roundtrip' perfectly
@@ -715,17 +715,6 @@ def open_dataarray(
     return data_array


-class _MultiFileCloser:
-    __slots__ = ("file_objs",)
-
-    def __init__(self, file_objs):
-        self.file_objs = file_objs
-
-    def close(self):
-        for f in self.file_objs:
-            f.close()
-
-
 def open_mfdataset(
     paths,
     chunks=None,
@@ -918,14 +907,14 @@ def open_mfdataset(
         getattr_ = getattr

     datasets = [open_(p, **open_kwargs) for p in paths]
-    file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
+    closers = [getattr_(ds, "_close") for ds in datasets]
     if preprocess is not None:
         datasets = [preprocess(ds) for ds in datasets]

     if parallel:
         # calling compute here will return the datasets/file_objs lists,
         # the underlying datasets will still be stored as dask arrays
-        datasets, file_objs = dask.compute(datasets, file_objs)
+        datasets, closers = dask.compute(datasets, closers)

     # Combine all datasets, closing them in case of a ValueError
     try:
@@ -963,7 +952,11 @@ def open_mfdataset(
             ds.close()
         raise

-    combined._file_obj = _MultiFileCloser(file_objs)
+    def multi_file_closer():
+        for closer in closers:
+            closer()
+
+    combined.set_close(multi_file_closer)

     # read global attributes from the attrs_file or from the first dataset
     if attrs_file is not None:
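The removed ``_MultiFileCloser`` helper class is replaced by a plain closure built inside ``open_mfdataset``. A standalone sketch of the same pattern; ``make_multi_file_closer`` and the ``closers`` list are illustrative names, not part of xarray's API:

```python
from typing import Callable, List


def make_multi_file_closer(closers: List[Callable[[], None]]) -> Callable[[], None]:
    """Return a zero-argument callable that invokes every per-dataset closer."""

    def multi_file_closer() -> None:
        for closer in closers:
            closer()

    return multi_file_closer


# In open_mfdataset the equivalent closure is then registered with:
# combined.set_close(multi_file_closer)
```

Compared with the old helper class, the closure needs no ``__slots__`` or constructor boilerplate and captures the per-file closers directly.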

xarray/backends/apiv2.py (+1, -1)

@@ -90,7 +90,7 @@ def _dataset_from_backend_dataset(
         **extra_tokens,
     )

-    ds._file_obj = backend_ds._file_obj
+    ds.set_close(backend_ds._close)

     # Ensure source filename always stored in dataset object (GH issue #2550)
     if "source" not in ds.encoding:

xarray/backends/rasterio_.py (+1, -1)

@@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
         result = result.chunk(chunks, name_prefix=name_prefix, token=token)

     # Make the file closeable
-    result._file_obj = manager
+    result.set_close(manager.close)

     return result
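Since the new hook accepts any zero-argument callable, the rasterio file manager's bound ``close`` method can be registered directly. A toy sketch of that idea; ``DummyManager`` is a hypothetical stand-in for the real file manager:

```python
import xarray as xr


class DummyManager:
    """Stand-in for a backend file manager that owns an open resource."""

    def close(self) -> None:
        print("resource released")


result = xr.DataArray([1.0, 2.0], dims="x")
result.set_close(DummyManager().close)  # a bound method is itself a callable
result.close()  # prints "resource released"
```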

xarray/backends/store.py (+1, -2)

@@ -19,7 +19,6 @@ def open_backend_dataset_store(
     decode_timedelta=None,
 ):
     vars, attrs = store.load()
-    file_obj = store
     encoding = store.get_encoding()

     vars, attrs, coord_names = conventions.decode_cf_variables(
@@ -36,7 +35,7 @@ def open_backend_dataset_store(

     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.intersection(vars))
-    ds._file_obj = file_obj
+    ds.set_close(store.close)
     ds.encoding = encoding

     return ds

xarray/conventions.py (+3, -3)

@@ -576,12 +576,12 @@ def decode_cf(
         vars = obj._variables
         attrs = obj.attrs
         extra_coords = set(obj.coords)
-        file_obj = obj._file_obj
+        close = obj._close
         encoding = obj.encoding
     elif isinstance(obj, AbstractDataStore):
         vars, attrs = obj.load()
         extra_coords = set()
-        file_obj = obj
+        close = obj.close
         encoding = obj.get_encoding()
     else:
         raise TypeError("can only decode Dataset or DataStore objects")
@@ -599,7 +599,7 @@ def decode_cf(
     )
     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
-    ds._file_obj = file_obj
+    ds.set_close(close)
     ds.encoding = encoding

     return ds
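With this change ``decode_cf`` hands the original object's closer to the decoded ``Dataset``, so closing the decoded result still releases the underlying resource. A hedged usage sketch; the file name is a placeholder:

```python
import xarray as xr

raw = xr.open_dataset("example.nc", decode_cf=False)  # placeholder path

decoded = xr.decode_cf(raw)  # the closer from ``raw`` is registered on ``decoded``
decoded.close()              # releases the file handle that backs ``raw``
```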

xarray/core/common.py (+24, -5)

@@ -11,6 +11,7 @@
     Iterator,
     List,
     Mapping,
+    Optional,
     Tuple,
     TypeVar,
     Union,
@@ -330,7 +331,9 @@ def get_squeeze_dims(
 class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
     """Shared base class for Dataset and DataArray."""

-    __slots__ = ()
+    _close: Optional[Callable[[], None]]
+
+    __slots__ = ("_close",)

     _rolling_exp_cls = RollingExp

@@ -1263,11 +1266,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):

         return ops.where_method(self, cond, other)

+    def set_close(self, close: Optional[Callable[[], None]]) -> None:
+        """Register the function that releases any resources linked to this object.
+
+        This method controls how xarray cleans up resources associated
+        with this object when the ``.close()`` method is called. It is mostly
+        intended for backend developers and it is rarely needed by regular
+        end-users.
+
+        Parameters
+        ----------
+        close : callable
+            The function that when called like ``close()`` releases
+            any resources linked to this object.
+        """
+        self._close = close
+
     def close(self: Any) -> None:
-        """Close any files linked to this object"""
-        if self._file_obj is not None:
-            self._file_obj.close()
-        self._file_obj = None
+        """Release any resources linked to this object."""
+        if self._close is not None:
+            self._close()
+        self._close = None

     def isnull(self, keep_attrs: bool = None):
         """Test each value in the array for whether it is a missing value.

xarray/core/dataarray.py (+3, -2)

@@ -344,14 +344,15 @@ class DataArray(AbstractArray, DataWithCoords):

     _cache: Dict[str, Any]
     _coords: Dict[Any, Variable]
+    _close: Optional[Callable[[], None]]
     _indexes: Optional[Dict[Hashable, pd.Index]]
     _name: Optional[Hashable]
     _variable: Variable

     __slots__ = (
         "_cache",
         "_coords",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_name",
         "_variable",
@@ -421,7 +422,7 @@ def __init__(
         # public interface.
         self._indexes = indexes

-        self._file_obj = None
+        self._close = None

     def _replace(
         self,

xarray/core/dataset.py (+9, -8)

@@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
     _coord_names: Set[Hashable]
     _dims: Dict[Hashable, int]
     _encoding: Optional[Dict[Hashable, Any]]
+    _close: Optional[Callable[[], None]]
     _indexes: Optional[Dict[Hashable, pd.Index]]
     _variables: Dict[Hashable, Variable]

@@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
         "_coord_names",
         "_dims",
         "_encoding",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_variables",
         "__weakref__",
@@ -687,7 +688,7 @@ def __init__(
         )

         self._attrs = dict(attrs) if attrs is not None else None
-        self._file_obj = None
+        self._close = None
         self._encoding = None
         self._variables = variables
         self._coord_names = coord_names
@@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
         if decoder:
             variables, attributes = decoder(variables, attributes)
         obj = cls(variables, attrs=attributes)
-        obj._file_obj = store
+        obj.set_close(store.close)
         return obj

     @property
@@ -876,7 +877,7 @@ def __dask_postcompute__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postcompute, args

@@ -896,7 +897,7 @@ def __dask_postpersist__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postpersist, args

@@ -1007,7 +1008,7 @@ def _construct_direct(
         attrs=None,
         indexes=None,
         encoding=None,
-        file_obj=None,
+        close=None,
     ):
         """Shortcut around __init__ for internal use when we want to skip
         costly validation
@@ -1020,7 +1021,7 @@ def _construct_direct(
         obj._dims = dims
         obj._indexes = indexes
         obj._attrs = attrs
-        obj._file_obj = file_obj
+        obj._close = close
         obj._encoding = encoding
         return obj

@@ -2122,7 +2123,7 @@ def isel(
             attrs=self._attrs,
             indexes=indexes,
             encoding=self._encoding,
-            file_obj=self._file_obj,
+            close=self._close,
         )

     def _isel_fancy(
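Because ``_construct_direct`` now forwards ``close``, derived objects such as the result of ``isel`` keep the ability to release the original resources. A small illustrative sketch (the print-based closer is hypothetical):

```python
import xarray as xr

ds = xr.Dataset({"a": (("t", "x"), [[1, 2], [3, 4]])})
ds.set_close(lambda: print("released"))  # illustrative closer

view = ds.isel(t=0)  # the closer travels along via _construct_direct(close=...)
view.close()         # prints "released"
```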
