Skip to content

Commit e4cc6af

Browse files
authored
Pytests for CuPy zonal stats (#658)
* adds unit-tests for cupy-zonal * fix bug when using zone_id list * fixes bug related to using zone_ids list * removes comment-out code * flake8 compatible
1 parent 84f37e3 commit e4cc6af

File tree

2 files changed

+41
-17
lines changed

2 files changed

+41
-17
lines changed

xrspatial/tests/test_zonal.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from xrspatial import trim
1414
from xrspatial import crop
1515
from xrspatial.zonal import regions
16+
from xrspatial.utils import doesnt_have_cuda
17+
1618

1719
from xrspatial.tests.general_checks import create_test_raster
1820

@@ -169,26 +171,35 @@ def check_results(backend, df_result, expected_results_dict):
169171
np.testing.assert_allclose(df_result[col], expected_results_dict[col])
170172

171173

172-
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
174+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
173175
def test_default_stats(backend, data_zones, data_values_2d, result_default_stats):
176+
if backend == 'cupy' and doesnt_have_cuda():
177+
pytest.skip("CUDA Device not Available")
174178
df_result = stats(zones=data_zones, values=data_values_2d)
175179
check_results(backend, df_result, result_default_stats)
176180

177181

178-
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
182+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
179183
def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_stats):
184+
if backend == 'cupy' and doesnt_have_cuda():
185+
pytest.skip("CUDA Device not Available")
180186
zone_ids, expected_result = result_zone_ids_stats
181-
df_result = stats(zones=data_zones, values=data_values_2d, zone_ids=zone_ids)
187+
df_result = stats(zones=data_zones, values=data_values_2d,
188+
zone_ids=zone_ids)
182189
check_results(backend, df_result, expected_result)
183190

184191

185-
@pytest.mark.parametrize("backend", ['numpy'])
192+
@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
186193
def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
187-
# ---- custom stats (NumPy only) ----
194+
# ---- custom stats (NumPy and CuPy only) ----
195+
if backend == 'cupy' and doesnt_have_cuda():
196+
pytest.skip("CUDA Device not Available")
197+
188198
custom_stats = {
189199
'double_sum': _double_sum,
190200
'range': _range,
191201
}
202+
192203
nodata_values, zone_ids, expected_result = result_custom_stats
193204
df_result = stats(
194205
zones=data_zones, values=data_values_2d, stats_funcs=custom_stats,
@@ -219,7 +230,8 @@ def test_percentage_crosstab_2d(backend, data_zones, data_values_2d, result_perc
219230
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
220231
def test_crosstab_3d(backend, data_zones, data_values_3d, result_crosstab_3d):
221232
layer, zone_ids, expected_result = result_crosstab_3d
222-
df_result = crosstab(zones=data_zones, values=data_values_3d, zone_ids=zone_ids, layer=layer)
233+
df_result = crosstab(zones=data_zones, values=data_values_3d,
234+
zone_ids=zone_ids, layer=layer)
223235
check_results(backend, df_result, expected_result)
224236

225237

xrspatial/zonal.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -324,16 +324,30 @@ def _stats_cupy(
324324
sorted_zones = sorted_zones[filter_values]
325325

326326
# Now I need to find the unique zones, and zone breaks
327-
unique_zones, unique_index = cupy.unique(sorted_zones, return_index=True)
327+
unique_zones, unique_index, unique_counts = cupy.unique(
328+
sorted_zones, return_index=True, return_counts=True)
328329

329330
# Transfer to the host
330331
unique_index = unique_index.get()
331-
if zone_ids is None:
332-
unique_zones = unique_zones.get()
333-
else:
332+
unique_counts = unique_counts.get()
333+
unique_zones = unique_zones.get()
334+
335+
if zone_ids is not None:
336+
# We need to extract the index and element count
337+
# only for the elements in zone_ids
338+
unique_index_lst = []
339+
unique_counts_lst = []
340+
unique_zones = list(unique_zones)
341+
for z in zone_ids:
342+
try:
343+
idx = unique_zones.index(z)
344+
unique_index_lst.append(unique_index[idx])
345+
unique_counts_lst.append(unique_counts[idx])
346+
except ValueError:
347+
continue
334348
unique_zones = zone_ids
335-
# unique_zones = list(map(_to_int, unique_zones))
336-
unique_zones = np.asarray(unique_zones)
349+
unique_counts = unique_counts_lst
350+
unique_index = unique_index_lst
337351

338352
# stats columns
339353
stats_dict = {'zone': []}
@@ -347,11 +361,9 @@ def _stats_cupy(
347361
continue
348362

349363
stats_dict['zone'].append(zone_id)
364+
350365
# extract zone_values
351-
if i < len(unique_zones) - 1:
352-
zone_values = values_by_zone[unique_index[i]:unique_index[i+1]]
353-
else:
354-
zone_values = values_by_zone[unique_index[i]:]
366+
zone_values = values_by_zone[unique_index[i]:unique_index[i]+unique_counts[i]]
355367

356368
# apply stats on the zone data
357369
for j, stats in enumerate(stats_funcs):
@@ -362,7 +374,7 @@ def _stats_cupy(
362374

363375
assert(len(result.shape) == 0)
364376

365-
stats_dict[stats].append(cupy.float(result))
377+
stats_dict[stats].append(cupy.float_(result))
366378

367379
stats_df = pd.DataFrame(stats_dict)
368380
stats_df.set_index("zone")

0 commit comments

Comments
 (0)