Skip to content

Commit 84ce139

Browse files
committed
pygmt.grdcut: Support both grid and image output
1 parent 48e6505 commit 84ce139

File tree

4 files changed

+52
-18
lines changed

4 files changed

+52
-18
lines changed

pygmt/clib/session.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
vectors_to_arrays,
2626
)
2727
from pygmt.clib.loading import load_libgmt
28-
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID
28+
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
2929
from pygmt.exceptions import (
3030
GMTCLibError,
3131
GMTCLibNoSessionError,
3232
GMTInvalidInput,
3333
GMTVersionError,
3434
)
3535
from pygmt.helpers import (
36+
GMTTempFile,
3637
data_kind,
3738
fmt_docstring,
3839
tempfile_from_geojson,
@@ -1649,7 +1650,9 @@ def virtualfile_in( # noqa: PLR0912
16491650

16501651
@contextlib.contextmanager
16511652
def virtualfile_out(
1652-
self, kind: Literal["dataset", "grid"] = "dataset", fname: str | None = None
1653+
self,
1654+
kind: Literal["dataset", "grid", "image"] = "dataset",
1655+
fname: str | None = None,
16531656
):
16541657
r"""
16551658
Create a virtual file or an actual file for storing output data.
@@ -1706,8 +1709,11 @@ def virtualfile_out(
17061709
family, geometry = {
17071710
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP"),
17081711
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE"),
1712+
"image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE"),
17091713
}[kind]
1710-
with self.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
1714+
with self.open_virtualfile(
1715+
family, geometry, "GMT_OUT|GMT_IS_REFERENCE", None
1716+
) as vfile:
17111717
yield vfile
17121718

17131719
def inquire_virtualfile(self, vfname: str) -> int:
@@ -1801,9 +1807,13 @@ def read_virtualfile(
18011807
# _GMT_DATASET).
18021808
if kind is None: # Return the ctypes void pointer
18031809
return pointer
1804-
if kind in ["image", "cube"]:
1810+
if kind == "cube":
18051811
raise NotImplementedError(f"kind={kind} is not supported yet.")
1806-
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind]
1812+
dtype = {
1813+
"dataset": _GMT_DATASET,
1814+
"grid": _GMT_GRID,
1815+
"image": _GMT_IMAGE,
1816+
}[kind]
18071817
return ctp.cast(pointer, ctp.POINTER(dtype))
18081818

18091819
def virtualfile_to_dataset(
@@ -2013,6 +2023,10 @@ def virtualfile_to_raster(
20132023
self["GMT_IS_IMAGE"]: "image",
20142024
self["GMT_IS_CUBE"]: "cube",
20152025
}[family]
2026+
if kind == "image":
2027+
with GMTTempFile(suffix=".tif") as tmpfile:
2028+
self.call_module("write", f"{vfname} {tmpfile.name} -Ti")
2029+
return xr.load_dataarray(tmpfile.name)
20162030
return self.read_virtualfile(vfname, kind=kind).contents.to_dataarray()
20172031

20182032
def extract_region(self):

pygmt/datatypes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from pygmt.datatypes.dataset import _GMT_DATASET
66
from pygmt.datatypes.grid import _GMT_GRID
7+
from pygmt.datatypes.image import _GMT_IMAGE

pygmt/datatypes/image.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
Wrapper for the GMT_GRID data type.
3+
"""
4+
5+
import ctypes as ctp
6+
7+
8+
class _GMT_IMAGE(ctp.Structure): # noqa: N801
9+
pass

pygmt/src/grdcut.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,19 @@
44

55
from pygmt.clib import Session
66
from pygmt.helpers import (
7-
GMTTempFile,
87
build_arg_string,
8+
data_kind,
99
fmt_docstring,
1010
kwargs_to_strings,
1111
use_alias,
1212
)
13-
from pygmt.io import load_dataarray
13+
from pygmt.src.which import which
1414

1515
__doctest_skip__ = ["grdcut"]
1616

1717

1818
@fmt_docstring
1919
@use_alias(
20-
G="outgrid",
2120
R="region",
2221
J="projection",
2322
N="extend",
@@ -27,7 +26,7 @@
2726
f="coltypes",
2827
)
2928
@kwargs_to_strings(R="sequence")
30-
def grdcut(grid, **kwargs):
29+
def grdcut(grid, outgrid: str | None = None, **kwargs):
3130
r"""
3231
Extract subregion from a grid.
3332
@@ -99,13 +98,24 @@ def grdcut(grid, **kwargs):
9998
>>> # 12° E to 15° E and a latitude range of 21° N to 24° N
10099
>>> new_grid = pygmt.grdcut(grid=grid, region=[12, 15, 21, 24])
101100
"""
102-
with GMTTempFile(suffix=".nc") as tmpfile:
103-
with Session() as lib:
104-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
105-
if (outgrid := kwargs.get("G")) is None:
106-
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
107-
lib.call_module(
108-
module="grdcut", args=build_arg_string(kwargs, infile=vingrd)
109-
)
101+
inkind = data_kind(grid)
102+
outkind = "image" if inkind == "image" else "grid"
103+
if inkind == "file":
104+
realpath = which(str(grid))
105+
if isinstance(realpath, list):
106+
realpath = realpath[0]
107+
if realpath.endswith(".tif"):
108+
outkind = "image"
110109

111-
return load_dataarray(outgrid) if outgrid == tmpfile.name else None
110+
with Session() as lib:
111+
with (
112+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
113+
lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd,
114+
):
115+
kwargs["G"] = voutgrd
116+
lib.call_module(
117+
module="grdcut", args=build_arg_string(kwargs, infile=vingrd)
118+
)
119+
return lib.virtualfile_to_raster(
120+
outgrid=outgrid, kind=outkind, vfname=voutgrd
121+
)

0 commit comments

Comments
 (0)