Skip to content

Commit c054314

Browse files
committed
Fix DataArray.__dask_scheduler__ to point to dask.threaded.get
Previously this erroneously pointed to an optimize function, likely a copy-paste error. For testing this also redirects the .compute methods to use the dask.compute function directly *if* dask.__version__ >= '0.16.0'.
1 parent c2b205f commit c054314

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

xarray/core/dataarray.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def __dask_optimize__(self):
594594

595595
@property
596596
def __dask_scheduler__(self):
597-
return self._to_temp_dataset().__dask_optimize__
597+
return self._to_temp_dataset().__dask_scheduler__
598598

599599
def __dask_postcompute__(self):
600600
func, args = self._to_temp_dataset().__dask_postcompute__()
@@ -654,8 +654,12 @@ def compute(self, **kwargs):
654654
--------
655655
dask.array.compute
656656
"""
657-
new = self.copy(deep=False)
658-
return new.load(**kwargs)
657+
import dask
658+
if dask.__version__ >= '0.16.0':
659+
return dask.compute(self, **kwargs)[0]
660+
else:
661+
new = self.copy(deep=False)
662+
return new.load(**kwargs)
659663

660664
def persist(self, **kwargs):
661665
""" Trigger computation in constituent dask arrays
@@ -673,8 +677,12 @@ def persist(self, **kwargs):
673677
--------
674678
dask.persist
675679
"""
676-
ds = self._to_temp_dataset().persist(**kwargs)
677-
return self._from_temp_dataset(ds)
680+
import dask
681+
if dask.__version__ >= '0.16.0':
682+
return dask.persist(self, **kwargs)[0]
683+
else:
684+
ds = self._to_temp_dataset().persist(**kwargs)
685+
return self._from_temp_dataset(ds)
678686

679687
def copy(self, deep=True):
680688
"""Returns a copy of this array.

xarray/core/dataset.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,12 @@ def compute(self, **kwargs):
590590
--------
591591
dask.array.compute
592592
"""
593-
new = self.copy(deep=False)
594-
return new.load(**kwargs)
593+
import dask
594+
if dask.__version__ >= '0.16.0':
595+
return dask.compute(self, **kwargs)[0]
596+
else:
597+
new = self.copy(deep=False)
598+
return new.load(**kwargs)
595599

596600
def _persist_inplace(self, **kwargs):
597601
""" Persist all Dask arrays in memory """
@@ -627,8 +631,12 @@ def persist(self, **kwargs):
627631
--------
628632
dask.persist
629633
"""
630-
new = self.copy(deep=False)
631-
return new._persist_inplace(**kwargs)
634+
import dask
635+
if dask.__version__ >= '0.16.0':
636+
return dask.persist(self, **kwargs)[0]
637+
else:
638+
new = self.copy(deep=False)
639+
return new._persist_inplace(**kwargs)
632640

633641
@classmethod
634642
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,

xarray/core/variable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,12 @@ def compute(self, **kwargs):
352352
--------
353353
dask.array.compute
354354
"""
355-
new = self.copy(deep=False)
356-
return new.load(**kwargs)
355+
import dask
356+
if dask.__version__ >= '0.16.0':
357+
return dask.compute(self, **kwargs)[0]
358+
else:
359+
new = self.copy(deep=False)
360+
return new.load(**kwargs)
357361

358362
def __dask_graph__(self):
359363
if isinstance(self._data, dask_array_type):

0 commit comments

Comments
 (0)