Skip to content

Commit 6a88408

Browse files
seismanweiji14
andauthored
**BREAKING** pygmt.grdcut: Refactor to store output in virtualfiles for grids/images (#3115)
Co-authored-by: Wei Ji <[email protected]>
1 parent d9fb903 commit 6a88408

File tree

3 files changed

+132
-14
lines changed

3 files changed

+132
-14
lines changed

pygmt/src/grdcut.py

+36-14
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22
grdcut - Extract subregion from a grid.
33
"""
44

5+
from typing import Literal
6+
57
import xarray as xr
68
from pygmt.clib import Session
9+
from pygmt.exceptions import GMTInvalidInput
710
from pygmt.helpers import (
8-
GMTTempFile,
911
build_arg_list,
12+
data_kind,
1013
fmt_docstring,
1114
kwargs_to_strings,
1215
use_alias,
1316
)
14-
from pygmt.io import load_dataarray
1517

1618
__doctest_skip__ = ["grdcut"]
1719

1820

1921
@fmt_docstring
2022
@use_alias(
21-
G="outgrid",
2223
R="region",
2324
J="projection",
2425
N="extend",
@@ -28,9 +29,11 @@
2829
f="coltypes",
2930
)
3031
@kwargs_to_strings(R="sequence")
31-
def grdcut(grid, **kwargs) -> xr.DataArray | None:
32+
def grdcut(
33+
grid, kind: Literal["grid", "image"] = "grid", outgrid: str | None = None, **kwargs
34+
) -> xr.DataArray | None:
3235
r"""
33-
Extract subregion from a grid.
36+
Extract subregion from a grid or image.
3437
3538
Produce a new ``outgrid`` file which is a subregion of ``grid``. The
3639
subregion is specified with ``region``; the specified range must not exceed
@@ -48,6 +51,11 @@ def grdcut(grid, **kwargs) -> xr.DataArray | None:
4851
Parameters
4952
----------
5053
{grid}
54+
kind
55+
The raster data kind. Valid values are ``"grid"`` and ``"image"``. When the
56+
input ``grid`` is a file name, it's difficult to determine if the file is a grid
57+
or an image, so we need to specify the raster kind explicitly. The default is
58+
``"grid"``.
5159
{outgrid}
5260
{projection}
5361
{region}
@@ -100,13 +108,27 @@ def grdcut(grid, **kwargs) -> xr.DataArray | None:
100108
>>> # 12° E to 15° E and a latitude range of 21° N to 24° N
101109
>>> new_grid = pygmt.grdcut(grid=grid, region=[12, 15, 21, 24])
102110
"""
103-
with GMTTempFile(suffix=".nc") as tmpfile:
104-
with Session() as lib:
105-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
106-
if (outgrid := kwargs.get("G")) is None:
107-
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
108-
lib.call_module(
109-
module="grdcut", args=build_arg_list(kwargs, infile=vingrd)
110-
)
111+
if kind not in {"grid", "image"}:
112+
msg = f"Invalid raster kind: '{kind}'. Valid values are 'grid' and 'image'."
113+
raise GMTInvalidInput(msg)
114+
115+
# Determine the output data kind based on the input data kind.
116+
match inkind := data_kind(grid):
117+
case "grid" | "image":
118+
outkind = inkind
119+
case "file":
120+
outkind = kind
121+
case _:
122+
msg = f"Unsupported data type {type(grid)}."
123+
raise GMTInvalidInput(msg)
111124

112-
return load_dataarray(outgrid) if outgrid == tmpfile.name else None
125+
with Session() as lib:
126+
with (
127+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
128+
lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd,
129+
):
130+
kwargs["G"] = voutgrd
131+
lib.call_module(module="grdcut", args=build_arg_list(kwargs, infile=vingrd))
132+
return lib.virtualfile_to_raster(
133+
vfname=voutgrd, kind=outkind, outgrid=outgrid
134+
)

pygmt/tests/test_grdcut.py

+8
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,11 @@ def test_grdcut_fails():
7070
"""
7171
with pytest.raises(GMTInvalidInput):
7272
grdcut(np.arange(10).reshape((5, 2)))
73+
74+
75+
def test_grdcut_invalid_kind(grid, region):
76+
"""
77+
Check that grdcut fails with incorrect 'kind'.
78+
"""
79+
with pytest.raises(GMTInvalidInput):
80+
grdcut(grid, kind="invalid", region=region)

pygmt/tests/test_grdcut_image.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Test pygmt.grdcut on images.
3+
"""
4+
5+
from pathlib import Path
6+
7+
import numpy as np
8+
import pytest
9+
import xarray as xr
10+
from pygmt import grdcut
11+
from pygmt.datasets import load_blue_marble
12+
from pygmt.helpers import GMTTempFile
13+
14+
try:
15+
import rioxarray
16+
17+
_HAS_RIOXARRAY = True
18+
except ImportError:
19+
_HAS_RIOXARRAY = False
20+
21+
22+
@pytest.fixture(scope="module", name="region")
23+
def fixture_region():
24+
"""
25+
Set the data region.
26+
"""
27+
return [-53, -49, -20, -17]
28+
29+
30+
@pytest.fixture(scope="module", name="expected_image")
31+
def fixture_expected_image():
32+
"""
33+
Load the expected grdcut image result.
34+
"""
35+
return xr.DataArray(
36+
data=np.array(
37+
[
38+
[[90, 93, 95, 90], [91, 90, 91, 91], [91, 90, 89, 90]],
39+
[[87, 88, 88, 89], [88, 87, 86, 85], [90, 90, 89, 88]],
40+
[[48, 49, 49, 45], [49, 48, 47, 45], [48, 47, 48, 46]],
41+
],
42+
dtype=np.uint8,
43+
),
44+
coords={
45+
"band": [1, 2, 3],
46+
"x": [-52.5, -51.5, -50.5, -49.5],
47+
"y": [-17.5, -18.5, -19.5],
48+
},
49+
dims=["band", "y", "x"],
50+
attrs={
51+
"scale_factor": 1.0,
52+
"add_offset": 0.0,
53+
},
54+
)
55+
56+
57+
@pytest.mark.benchmark
58+
def test_grdcut_image_file(region, expected_image):
59+
"""
60+
Test grdcut on an input image file.
61+
"""
62+
result = grdcut("@earth_day_01d", region=region, kind="image")
63+
xr.testing.assert_allclose(a=result, b=expected_image)
64+
65+
66+
@pytest.mark.benchmark
67+
@pytest.mark.skipif(not _HAS_RIOXARRAY, reason="rioxarray is not installed")
68+
def test_grdcut_image_dataarray(region, expected_image):
69+
"""
70+
Test grdcut on an input xarray.DataArray object.
71+
"""
72+
raster = load_blue_marble()
73+
result = grdcut(raster, region=region, kind="image")
74+
xr.testing.assert_allclose(a=result, b=expected_image)
75+
76+
77+
def test_grdcut_image_file_in_file_out(region, expected_image):
78+
"""
79+
Test grdcut on an input image file and outputs to another image file.
80+
"""
81+
with GMTTempFile(suffix=".tif") as tmp:
82+
result = grdcut("@earth_day_01d", region=region, outgrid=tmp.name)
83+
assert result is None
84+
assert Path(tmp.name).stat().st_size > 0
85+
if _HAS_RIOXARRAY:
86+
with rioxarray.open_rasterio(tmp.name) as raster:
87+
image = raster.load().drop_vars("spatial_ref")
88+
xr.testing.assert_allclose(a=image, b=expected_image)

0 commit comments

Comments
 (0)