Skip to content

Commit 809880c

Browse files
seisman, yvonnefroehlich, weiji14
authored
Wrap GMT's standard data type GMT_GRID for grids and refactor wrappers to use virtualfiles for output grids (#2398)
Co-authored-by: Yvonne Fröhlich <[email protected]> Co-authored-by: Wei Ji <[email protected]>
1 parent 35ed27a commit 809880c

23 files changed

+481
-339
lines changed

doc/api/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ Python objects to and from GMT virtual files:
292292
clib.Session.virtualfile_in
293293
clib.Session.virtualfile_out
294294
clib.Session.virtualfile_to_dataset
295+
clib.Session.virtualfile_to_raster
295296

296297
Low level access (these are mostly used by the :mod:`pygmt.clib` package):
297298

pygmt/clib/session.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import xarray as xr
1718
from packaging.version import Version
1819
from pygmt.clib.conversion import (
1920
array_to_datetime,
@@ -1739,7 +1740,9 @@ def inquire_virtualfile(self, vfname: str) -> int:
17391740
return c_inquire_virtualfile(self.session_pointer, vfname.encode())
17401741

17411742
def read_virtualfile(
1742-
self, vfname: str, kind: Literal["dataset", "grid", None] = None
1743+
self,
1744+
vfname: str,
1745+
kind: Literal["dataset", "grid", "image", "cube", None] = None,
17431746
):
17441747
"""
17451748
Read data from a virtual file and optionally cast into a GMT data container.
@@ -1798,6 +1801,8 @@ def read_virtualfile(
17981801
# _GMT_DATASET).
17991802
if kind is None: # Return the ctypes void pointer
18001803
return pointer
1804+
if kind in ["image", "cube"]:
1805+
raise NotImplementedError(f"kind={kind} is not supported yet.")
18011806
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind]
18021807
return ctp.cast(pointer, ctp.POINTER(dtype))
18031808

@@ -1946,6 +1951,70 @@ def virtualfile_to_dataset(
19461951
return result.to_numpy()
19471952
return result # pandas.DataFrame output
19481953

1954+
def virtualfile_to_raster(
    self,
    vfname: str,
    kind: Literal["grid", "image", "cube", None] = "grid",
    outgrid: str | None = None,
) -> xr.DataArray | None:
    """
    Convert the raster held in a virtual file into an :class:`xarray.DataArray`.

    The raster may be a grid, an image or a cube.

    Parameters
    ----------
    vfname
        Name of the virtual file holding the resulting grid/image/cube.
    kind
        Type of the raster data: ``"grid"``, ``"image"``, ``"cube"`` or ``None``.
        With ``None``, the data family is inquired from the virtual file and
        mapped to one of the three names.
    outgrid
        Name of the output grid/image/cube file. When given, the raster is taken
        to have been written to that file already, so nothing is read back and
        ``None`` is returned.

    Returns
    -------
    result
        The resulting grid/image/cube, or ``None`` if ``outgrid`` is given.

    Examples
    --------
    >>> from pathlib import Path
    >>> from pygmt.clib import Session
    >>> from pygmt.helpers import GMTTempFile
    >>> with Session() as lib:
    ...     # file output
    ...     with GMTTempFile(suffix=".nc") as tmpfile:
    ...         outgrid = tmpfile.name
    ...         with lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd:
    ...             lib.call_module("read", f"@earth_relief_01d_g {voutgrd} -Tg")
    ...             result = lib.virtualfile_to_raster(
    ...                 vfname=voutgrd, outgrid=outgrid
    ...             )
    ...             assert result is None
    ...             assert Path(outgrid).stat().st_size > 0
    ...
    ...     # xarray.DataArray output
    ...     outgrid = None
    ...     with lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd:
    ...         lib.call_module("read", f"@earth_relief_01d_g {voutgrd} -Tg")
    ...         result = lib.virtualfile_to_raster(vfname=voutgrd, outgrid=outgrid)
    ...         assert isinstance(result, xr.DataArray)
    """
    # An explicit output file name short-circuits everything: there is nothing
    # to read back from the virtual file.
    if outgrid is not None:
        return None

    if kind is None:
        # Ask which data family the virtual file holds, then translate that
        # family constant into the matching kind name.
        family = self.inquire_virtualfile(vfname)
        kinds_by_family = {
            self["GMT_IS_GRID"]: "grid",
            self["GMT_IS_IMAGE"]: "image",
            self["GMT_IS_CUBE"]: "cube",
        }
        kind = kinds_by_family[family]  # type: ignore[assignment]

    raster_pointer = self.read_virtualfile(vfname, kind=kind)
    return raster_pointer.contents.to_dataarray()
2017+
19492018
def extract_region(self):
19502019
"""
19512020
Extract the WESN bounding box of the currently active figure.

pygmt/datatypes/grid.py

+191-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,197 @@
33
"""
44

55
import ctypes as ctp
6+
from typing import ClassVar
7+
8+
import numpy as np
9+
import xarray as xr
10+
from pygmt.datatypes.header import _GMT_GRID_HEADER, gmt_grdfloat
611

712

813
class _GMT_GRID(ctp.Structure):  # noqa: N801
    """
    GMT grid structure for holding a grid and its header.

    This class is only meant for internal use and is not exposed to users. See the GMT
    source code gmt_resources.h for the original C structure definitions.

    Examples
    --------
    >>> from pygmt.clib import Session
    >>> with Session() as lib:
    ...     with lib.virtualfile_out(kind="grid") as voutgrd:
    ...         lib.call_module("read", f"@static_earth_relief.nc {voutgrd} -Tg")
    ...         # Read the grid from the virtual file
    ...         grid = lib.read_virtualfile(voutgrd, kind="grid").contents
    ...         # The grid header
    ...         header = grid.header.contents
    ...         # Access the header properties
    ...         print(header.n_rows, header.n_columns, header.registration)
    ...         print(header.wesn[:], header.z_min, header.z_max, header.inc[:])
    ...         print(header.z_scale_factor, header.z_add_offset)
    ...         print(header.x_units, header.y_units, header.z_units)
    ...         print(header.title)
    ...         print(header.command)
    ...         print(header.remark)
    ...         print(header.nm, header.size, header.complex_mode)
    ...         print(header.type, header.n_bands, header.mx, header.my)
    ...         print(header.pad[:])
    ...         print(header.mem_layout, header.nan_value, header.xy_off)
    ...         # The x and y coordinates
    ...         print(grid.x[: header.n_columns])
    ...         print(grid.y[: header.n_rows])
    ...         # The data array (with paddings)
    ...         data = np.reshape(
    ...             grid.data[: header.mx * header.my], (header.my, header.mx)
    ...         )
    ...         # The data array (without paddings)
    ...         pad = header.pad[:]
    ...         data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1]]
    ...         print(data)
    14 8 1
    [-55.0, -47.0, -24.0, -10.0] 190.0 981.0 [1.0, 1.0]
    1.0 0.0
    b'longitude [degrees_east]' b'latitude [degrees_north]' b'elevation (m)'
    b'Produced by grdcut'
    b'grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_earth_relief.nc'
    b'Reduced by Gaussian Cartesian filtering (111.2 km fullwidth) from ...'
    112 216 0
    18 1 12 18
    [2, 2, 2, 2]
    b'' nan 0.5
    [-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5]
    [-10.5, -11.5, -12.5, -13.5, -14.5, -15.5, ..., -22.5, -23.5]
    [[347.5 331.5 309.  282.  190.  208.  299.5 348. ]
     [349.  313.  325.5 247.  191.  225.  260.  452.5]
     [345.5 320.  335.  292.  207.5 247.  325.  346.5]
     [450.5 395.5 366.  248.  250.  354.5 550.  797.5]
     [494.5 488.5 357.  254.5 286.  484.5 653.5 930. ]
     [601.  526.5 535.  299.  398.5 645.  797.5 964. ]
     [308.  595.5 555.5 556.  580.  770.  927.  920. ]
     [521.5 682.5 796.  886.  571.5 638.5 739.5 881.5]
     [310.  521.5 757.  570.5 538.5 524.  686.5 794. ]
     [561.5 539.  446.5 481.5 439.5 553.  726.5 981. ]
     [557.  435.  385.5 345.5 413.5 496.  519.5 833.5]
     [373.  367.5 349.  352.5 419.5 428.  570.  667.5]
     [383.  284.5 344.5 394.  491.  556.5 578.5 618.5]
     [347.5 344.5 386.  640.5 617.  579.  646.5 671. ]]
    """

    _fields_: ClassVar = [
        # Pointer to full GMT header for grid
        ("header", ctp.POINTER(_GMT_GRID_HEADER)),
        # Pointer to grid data
        ("data", ctp.POINTER(gmt_grdfloat)),
        # Pointer to x coordinate vector
        ("x", ctp.POINTER(ctp.c_double)),
        # Pointer to y coordinate vector
        ("y", ctp.POINTER(ctp.c_double)),
        # Low-level information for GMT use only
        ("hidden", ctp.c_void_p),
    ]

    def to_dataarray(self) -> xr.DataArray:
        """
        Convert a _GMT_GRID object to a :class:`xarray.DataArray` object.

        The padded data buffer is reshaped to ``(header.my, header.mx)`` and the
        padding given by ``header.pad`` is sliced off. Any dimension whose first
        coordinate is larger than its second is then reversed so that coordinates
        are ascending; a dimension with a single coordinate is left untouched.
        Finally, the GMT accessor values (registration and grid type) are set
        from the header.

        Returns
        -------
        dataarray
            A :class:`xarray.DataArray` object.

        Examples
        --------
        >>> from pygmt.clib import Session
        >>> with Session() as lib:
        ...     with lib.virtualfile_out(kind="grid") as voutgrd:
        ...         lib.call_module("read", f"@static_earth_relief.nc {voutgrd} -Tg")
        ...         # Read the grid from the virtual file
        ...         grid = lib.read_virtualfile(voutgrd, kind="grid")
        ...         # Convert to xarray.DataArray and use it later
        ...         da = grid.contents.to_dataarray()
        >>> da  # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
        <xarray.DataArray 'z' (lat: 14, lon: 8)>...
        array([[347.5, 344.5, 386. , 640.5, 617. , 579. , 646.5, 671. ],
               [383. , 284.5, 344.5, 394. , 491. , 556.5, 578.5, 618.5],
               [373. , 367.5, 349. , 352.5, 419.5, 428. , 570. , 667.5],
               [557. , 435. , 385.5, 345.5, 413.5, 496. , 519.5, 833.5],
               [561.5, 539. , 446.5, 481.5, 439.5, 553. , 726.5, 981. ],
               [310. , 521.5, 757. , 570.5, 538.5, 524. , 686.5, 794. ],
               [521.5, 682.5, 796. , 886. , 571.5, 638.5, 739.5, 881.5],
               [308. , 595.5, 555.5, 556. , 580. , 770. , 927. , 920. ],
               [601. , 526.5, 535. , 299. , 398.5, 645. , 797.5, 964. ],
               [494.5, 488.5, 357. , 254.5, 286. , 484.5, 653.5, 930. ],
               [450.5, 395.5, 366. , 248. , 250. , 354.5, 550. , 797.5],
               [345.5, 320. , 335. , 292. , 207.5, 247. , 325. , 346.5],
               [349. , 313. , 325.5, 247. , 191. , 225. , 260. , 452.5],
               [347.5, 331.5, 309. , 282. , 190. , 208. , 299.5, 348. ]])
        Coordinates:
          * lat      (lat) float64... -23.5 -22.5 -21.5 -20.5 ... -12.5 -11.5 -10.5
          * lon      (lon) float64... -54.5 -53.5 -52.5 -51.5 -50.5 -49.5 -48.5 -47.5
        Attributes:
            Conventions:   CF-1.7
            title:         Produced by grdcut
            history:       grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_ea...
            description:   Reduced by Gaussian Cartesian filtering (111.2 km fullwi...
            long_name:     elevation (m)
            actual_range:  [190. 981.]
        >>> da.coords["lon"]  # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
        <xarray.DataArray 'lon' (lon: 8)>...
        array([-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5])
        Coordinates:
          * lon      (lon) float64... -54.5 -53.5 -52.5 -51.5 -50.5 -49.5 -48.5 -47.5
        Attributes:
            long_name:      longitude
            units:          degrees_east
            standard_name:  longitude
            axis:           X
            actual_range:   [-55. -47.]
        >>> da.coords["lat"]  # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
        <xarray.DataArray 'lat' (lat: 14)>...
        array([-23.5, -22.5, -21.5, -20.5, -19.5, -18.5, -17.5, -16.5, -15.5, -14.5,
               -13.5, -12.5, -11.5, -10.5])
        Coordinates:
          * lat      (lat) float64... -23.5 -22.5 -21.5 -20.5 ... -12.5 -11.5 -10.5
        Attributes:
            long_name:      latitude
            units:          degrees_north
            standard_name:  latitude
            axis:           Y
            actual_range:   [-24. -10.]
        >>> da.gmt.registration, da.gmt.gtype
        (1, 1)
        """
        # The grid header
        header = self.header.contents

        # Get dimensions and their attributes from the header.
        dims, dim_attrs = header.dims, header.dim_attrs
        # The coordinates, given as a tuple of the form (dims, data, attrs)
        coords = [
            (dims[0], self.y[: header.n_rows], dim_attrs[0]),
            (dims[1], self.x[: header.n_columns], dim_attrs[1]),
        ]

        # The data array without paddings
        pad = header.pad[:]
        data = np.reshape(self.data[: header.mx * header.my], (header.my, header.mx))[
            pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1]
        ]

        # Create the xarray.DataArray object
        grid = xr.DataArray(
            data, coords=coords, name=header.name, attrs=header.data_attrs
        )

        # Flip the coordinates and data if necessary so that coordinates are ascending.
        # `grid.sortby(list(grid.dims))` sometimes causes crashes.
        # The solution comes from https://github.com/pydata/xarray/discussions/6695.
        for dim in grid.dims:
            coord = grid[dim]
            # A dimension with a single coordinate has no direction to fix, and
            # indexing its second element would raise IndexError.
            if coord.size > 1 and coord[0] > coord[1]:
                grid = grid.isel({dim: slice(None, None, -1)})

        # Set GMT accessors.
        # Must put at the end, otherwise info gets lost after certain grid operations.
        grid.gmt.registration = header.registration
        grid.gmt.gtype = header.gtype
        return grid

pygmt/helpers/decorators.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,12 @@
267267
- ``file`` will save the result to the file specified by the ``outfile``
268268
parameter.""",
269269
"outgrid": """
270-
outgrid : str or None
271-
Name of the output netCDF grid file. For writing a specific grid
272-
file format or applying basic data operations to the output grid,
273-
see :gmt-docs:`gmt.html#grd-inout-full` for the available modifiers.""",
270+
outgrid
271+
Name of the output netCDF grid file. If not specified, will return an
272+
:class:`xarray.DataArray` object. For writing a specific grid file format or
273+
applying basic data operations to the output grid, see
274+
:gmt-docs:`gmt.html#grd-inout-full` for the available modifiers.
275+
""",
274276
"panel": r"""
275277
panel : bool, int, or list
276278
[*row,col*\|\ *index*].

pygmt/src/binstats.py

+12-20
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,13 @@
33
"""
44

55
from pygmt.clib import Session
6-
from pygmt.helpers import (
7-
GMTTempFile,
8-
build_arg_string,
9-
fmt_docstring,
10-
kwargs_to_strings,
11-
use_alias,
12-
)
13-
from pygmt.io import load_dataarray
6+
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias
147

158

169
@fmt_docstring
1710
@use_alias(
1811
C="statistic",
1912
E="empty",
20-
G="outgrid",
2113
I="spacing",
2214
N="normalize",
2315
R="region",
@@ -31,7 +23,7 @@
3123
r="registration",
3224
)
3325
@kwargs_to_strings(I="sequence", R="sequence", i="sequence_comma")
34-
def binstats(data, **kwargs):
26+
def binstats(data, outgrid: str | None = None, **kwargs):
3527
r"""
3628
Bin spatial data and determine statistics per bin.
3729
@@ -110,13 +102,13 @@ def binstats(data, **kwargs):
110102
- None if ``outgrid`` is set (grid output will be stored in file set by
111103
``outgrid``)
112104
"""
113-
with GMTTempFile(suffix=".nc") as tmpfile:
114-
with Session() as lib:
115-
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
116-
if (outgrid := kwargs.get("G")) is None:
117-
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
118-
lib.call_module(
119-
module="binstats", args=build_arg_string(kwargs, infile=vintbl)
120-
)
121-
122-
return load_dataarray(outgrid) if outgrid == tmpfile.name else None
105+
with Session() as lib:
106+
with (
107+
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
108+
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
109+
):
110+
kwargs["G"] = voutgrd
111+
lib.call_module(
112+
module="binstats", args=build_arg_string(kwargs, infile=vintbl)
113+
)
114+
return lib.virtualfile_to_raster(vfname=voutgrd, outgrid=outgrid)

0 commit comments

Comments
 (0)