diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py
index 8124883b6a0..8cd23f3947c 100644
--- a/asv_bench/benchmarks/groupby.py
+++ b/asv_bench/benchmarks/groupby.py
@@ -1,3 +1,5 @@
+# import flox to avoid the cost of first import
+import flox.xarray  # noqa
 import numpy as np
 import pandas as pd
 
@@ -27,24 +29,24 @@ def time_init(self, ndim):
     @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)])
     def time_agg_small_num_groups(self, method, ndim):
         ds = getattr(self, f"ds{ndim}d")
-        getattr(ds.groupby("a"), method)()
+        getattr(ds.groupby("a"), method)().compute()
 
     @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)])
     def time_agg_large_num_groups(self, method, ndim):
         ds = getattr(self, f"ds{ndim}d")
-        getattr(ds.groupby("b"), method)()
+        getattr(ds.groupby("b"), method)().compute()
 
     def time_binary_op_1d(self):
-        self.ds1d.groupby("b") - self.ds1d_mean
+        (self.ds1d.groupby("b") - self.ds1d_mean).compute()
 
     def time_binary_op_2d(self):
-        self.ds2d.groupby("b") - self.ds2d_mean
+        (self.ds2d.groupby("b") - self.ds2d_mean).compute()
 
     def peakmem_binary_op_1d(self):
-        self.ds1d.groupby("b") - self.ds1d_mean
+        (self.ds1d.groupby("b") - self.ds1d_mean).compute()
 
     def peakmem_binary_op_2d(self):
-        self.ds2d.groupby("b") - self.ds2d_mean
+        (self.ds2d.groupby("b") - self.ds2d_mean).compute()
 
 
 class GroupByDask(GroupBy):
@@ -56,8 +58,8 @@ def setup(self, *args, **kwargs):
         self.ds1d["c"] = self.ds1d["c"].chunk({"dim_0": 50})
         self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2))
         self.ds2d["c"] = self.ds2d["c"].chunk({"dim_0": 50, "z": 5})
-        self.ds1d_mean = self.ds1d.groupby("b").mean()
-        self.ds2d_mean = self.ds2d.groupby("b").mean()
+        self.ds1d_mean = self.ds1d.groupby("b").mean().compute()
+        self.ds2d_mean = self.ds2d.groupby("b").mean().compute()
 
 
 class GroupByPandasDataFrame(GroupBy):
@@ -88,7 +90,7 @@ def setup(self, *args, **kwargs):
         requires_dask()
         super().setup(**kwargs)
         self.ds1d = self.ds1d.chunk({"dim_0": 50}).to_dataframe()
-        self.ds1d_mean = self.ds1d.groupby("b").mean()
+        self.ds1d_mean = self.ds1d.groupby("b").mean().compute()
 
     def time_binary_op_2d(self):
         raise NotImplementedError
@@ -116,12 +118,12 @@ def time_init(self, ndim):
     @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)])
     def time_agg_small_num_groups(self, method, ndim):
         ds = getattr(self, f"ds{ndim}d")
-        getattr(ds.resample(time="3M"), method)()
+        getattr(ds.resample(time="3M"), method)().compute()
 
     @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)])
     def time_agg_large_num_groups(self, method, ndim):
         ds = getattr(self, f"ds{ndim}d")
-        getattr(ds.resample(time="48H"), method)()
+        getattr(ds.resample(time="48H"), method)().compute()
 
 
 class ResampleDask(Resample):