Skip to content

Commit 52043bc

Browse files
Add initial cupy tests (#4214)
* Add initial cupy tests * Linting * Docstrings
1 parent 7bf9df9 commit 52043bc

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

xarray/tests/test_cupy.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
import xarray as xr
6+
7+
cp = pytest.importorskip("cupy")
8+
9+
10+
@pytest.fixture
11+
def toy_weather_data():
12+
"""Construct the example DataSet from the Toy weather data example.
13+
14+
http://xarray.pydata.org/en/stable/examples/weather-data.html
15+
16+
Here we construct the DataSet exactly as shown in the example and then
17+
convert the numpy arrays to cupy.
18+
19+
"""
20+
np.random.seed(123)
21+
times = pd.date_range("2000-01-01", "2001-12-31", name="time")
22+
annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))
23+
24+
base = 10 + 15 * annual_cycle.reshape(-1, 1)
25+
tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)
26+
tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3)
27+
28+
ds = xr.Dataset(
29+
{
30+
"tmin": (("time", "location"), tmin_values),
31+
"tmax": (("time", "location"), tmax_values),
32+
},
33+
{"time": times, "location": ["IA", "IN", "IL"]},
34+
)
35+
36+
ds.tmax.data = cp.asarray(ds.tmax.data)
37+
ds.tmin.data = cp.asarray(ds.tmin.data)
38+
39+
return ds
40+
41+
42+
def test_cupy_import():
43+
"""Check the import worked."""
44+
assert cp
45+
46+
47+
def test_check_data_stays_on_gpu(toy_weather_data):
48+
"""Perform some operations and check the data stays on the GPU."""
49+
freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time")
50+
assert isinstance(freeze.data, cp.core.core.ndarray)

0 commit comments

Comments
 (0)