|
7 | 7 | import numpy as np
|
8 | 8 | import pandas as pd
|
9 | 9 | import pytest
|
| 10 | +from pandas.core.computation.ops import UndefinedVariableError |
10 | 11 | from pandas.tseries.frequencies import to_offset
|
11 | 12 |
|
12 | 13 | import xarray as xr
|
|
39 | 40 | requires_dask,
|
40 | 41 | requires_iris,
|
41 | 42 | requires_numbagg,
|
| 43 | + requires_numexpr, |
42 | 44 | requires_scipy,
|
43 | 45 | requires_sparse,
|
44 | 46 | source_ndarray,
|
@@ -4620,6 +4622,74 @@ def test_pad_reflect(self, mode, reflect_type):
|
4620 | 4622 | assert actual.shape == (7, 4, 9)
|
4621 | 4623 | assert_identical(actual, expected)
|
4622 | 4624 |
|
| 4625 | + @pytest.mark.parametrize("parser", ["pandas", "python"]) |
| 4626 | + @pytest.mark.parametrize( |
| 4627 | + "engine", ["python", None, pytest.param("numexpr", marks=[requires_numexpr])] |
| 4628 | + ) |
| 4629 | + @pytest.mark.parametrize( |
| 4630 | + "backend", ["numpy", pytest.param("dask", marks=[requires_dask])] |
| 4631 | + ) |
| 4632 | + def test_query(self, backend, engine, parser): |
| 4633 | + """Test querying a dataset.""" |
| 4634 | + |
| 4635 | + # setup test data |
| 4636 | + np.random.seed(42) |
| 4637 | + a = np.arange(0, 10, 1) |
| 4638 | + b = np.random.randint(0, 100, size=10) |
| 4639 | + c = np.linspace(0, 1, 20) |
| 4640 | + d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( |
| 4641 | + object |
| 4642 | + ) |
| 4643 | + if backend == "numpy": |
| 4644 | + aa = DataArray(data=a, dims=["x"], name="a") |
| 4645 | + bb = DataArray(data=b, dims=["x"], name="b") |
| 4646 | + cc = DataArray(data=c, dims=["y"], name="c") |
| 4647 | + dd = DataArray(data=d, dims=["z"], name="d") |
| 4648 | + |
| 4649 | + elif backend == "dask": |
| 4650 | + import dask.array as da |
| 4651 | + |
| 4652 | + aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a") |
| 4653 | + bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b") |
| 4654 | + cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c") |
| 4655 | + dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d") |
| 4656 | + |
| 4657 | + # query single dim, single variable |
| 4658 | + actual = aa.query(x="a > 5", engine=engine, parser=parser) |
| 4659 | + expect = aa.isel(x=(a > 5)) |
| 4660 | + assert_identical(expect, actual) |
| 4661 | + |
| 4662 | + # query single dim, single variable, via dict |
| 4663 | + actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser) |
| 4664 | + expect = aa.isel(dict(x=(a > 5))) |
| 4665 | + assert_identical(expect, actual) |
| 4666 | + |
| 4667 | + # query single dim, single variable |
| 4668 | + actual = bb.query(x="b > 50", engine=engine, parser=parser) |
| 4669 | + expect = bb.isel(x=(b > 50)) |
| 4670 | + assert_identical(expect, actual) |
| 4671 | + |
| 4672 | + # query single dim, single variable |
| 4673 | + actual = cc.query(y="c < .5", engine=engine, parser=parser) |
| 4674 | + expect = cc.isel(y=(c < 0.5)) |
| 4675 | + assert_identical(expect, actual) |
| 4676 | + |
| 4677 | + # query single dim, single string variable |
| 4678 | + if parser == "pandas": |
| 4679 | + # N.B., this query currently only works with the pandas parser |
| 4680 | + # xref https://github.com/pandas-dev/pandas/issues/40436 |
| 4681 | + actual = dd.query(z='d == "bar"', engine=engine, parser=parser) |
| 4682 | + expect = dd.isel(z=(d == "bar")) |
| 4683 | + assert_identical(expect, actual) |
| 4684 | + |
| 4685 | + # test error handling |
| 4686 | + with pytest.raises(ValueError): |
| 4687 | + aa.query("a > 5") # must be dict or kwargs |
| 4688 | + with pytest.raises(ValueError): |
| 4689 | + aa.query(x=(a > 5)) # must be query string |
| 4690 | + with pytest.raises(UndefinedVariableError): |
| 4691 | + aa.query(x="spam > 50") # name not present |
| 4692 | + |
4623 | 4693 |
|
4624 | 4694 | class TestReduce:
|
4625 | 4695 | @pytest.fixture(autouse=True)
|
|
0 commit comments