Skip to content

Commit 6d4ece4

Browse files
committed
Automatically detect grid registration from xarray data source
For xarray grids read from disk, there is an 'encoding' dictionary, with a 'source' key that gives the path to the file. Running `grdinfo` on that file can give us the grid registration. This works on the earth_relief grids. For grids that are not read from disk, we still default to assuming gridline registration ("GMT_GRID_NODE_REG").
1 parent 90b22a4 commit 6d4ece4

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

pygmt/base_plotting.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use_alias,
1919
kwargs_to_strings,
2020
)
21+
from .modules import grdinfo
2122

2223

2324
class BasePlotting:
@@ -56,6 +57,22 @@ def _preprocess(self, **kwargs): # pylint: disable=no-self-use
5657
"""
5758
return kwargs
5859

60+
def autodetect_registration(self, grid):
61+
"""
62+
Function to automatically detect whether the NetCDF source of an
63+
xarray.DataArray grid uses gridline or pixel registration. Defaults to
64+
gridline registration if grdinfo cannot find a source file.
65+
"""
66+
registration = "GMT_GRID_NODE_REG" # default to gridline registration
67+
68+
try:
69+
if "Pixel node registration used" in grdinfo(grid.encoding["source"]):
70+
registration = "GMT_GRID_PIXEL_REG"
71+
except KeyError:
72+
pass
73+
74+
return registration
75+
5976
@fmt_docstring
6077
@use_alias(
6178
R="region",
@@ -282,7 +299,8 @@ def grdcontour(self, grid, **kwargs):
282299
if kind == "file":
283300
file_context = dummy_context(grid)
284301
elif kind == "grid":
285-
file_context = lib.virtualfile_from_grid(grid)
302+
registration = self.autodetect_registration(grid)
303+
file_context = lib.virtualfile_from_grid(grid, registration)
286304
else:
287305
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
288306
with file_context as fname:
@@ -314,7 +332,8 @@ def grdimage(self, grid, **kwargs):
314332
if kind == "file":
315333
file_context = dummy_context(grid)
316334
elif kind == "grid":
317-
file_context = lib.virtualfile_from_grid(grid)
335+
registration = self.autodetect_registration(grid)
336+
file_context = lib.virtualfile_from_grid(grid, registration)
318337
else:
319338
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
320339
with file_context as fname:
@@ -410,7 +429,8 @@ def grdview(self, grid, **kwargs):
410429
if kind == "file":
411430
file_context = dummy_context(grid)
412431
elif kind == "grid":
413-
file_context = lib.virtualfile_from_grid(grid)
432+
registration = self.autodetect_registration(grid)
433+
file_context = lib.virtualfile_from_grid(grid, registration)
414434
else:
415435
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(grid)}")
416436

@@ -420,7 +440,10 @@ def grdview(self, grid, **kwargs):
420440
drapegrid = kwargs["G"]
421441
if data_kind(drapegrid) in ("file", "grid"):
422442
if data_kind(drapegrid) == "grid":
423-
drape_context = lib.virtualfile_from_grid(drapegrid)
443+
registration = self.autodetect_registration(grid)
444+
drape_context = lib.virtualfile_from_grid(
445+
drapegrid, registration
446+
)
424447
drapefile = stack.enter_context(drape_context)
425448
kwargs["G"] = drapefile
426449
else:

pygmt/clib/conversion.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..exceptions import GMTInvalidInput
88

99

10-
def dataarray_to_matrix(grid):
10+
def dataarray_to_matrix(grid, registration="GMT_GRID_NODE_REG"):
1111
"""
1212
Transform an xarray.DataArray into a data 2D array and metadata.
1313
@@ -27,6 +27,9 @@ def dataarray_to_matrix(grid):
2727
grid : xarray.DataArray
2828
The input grid as a DataArray instance. Information is retrieved from
2929
the coordinate arrays, not from headers.
30+
registration : str
31+
Either one of "GMT_GRID_PIXEL_REG" for pixel registration, or
32+
"GMT_GRID_NODE_REG" for gridline registration [Default].
3033
3134
Returns
3235
-------
@@ -102,7 +105,10 @@ def dataarray_to_matrix(grid):
102105
dim
103106
)
104107
)
105-
region.extend([coord.min(), coord.max()])
108+
if registration == "GMT_GRID_PIXEL_REG":
109+
region.extend([coord.min() - coord_inc / 2, coord.max() + coord_inc / 2])
110+
elif registration == "GMT_GRID_NODE_REG":
111+
region.extend([coord.min(), coord.max()])
106112
inc.append(coord_inc)
107113

108114
if any([i < 0 for i in inc]): # Sort grid when there are negative increments

pygmt/clib/session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def virtualfile_from_grid(self, grid, registration="GMT_GRID_NODE_REG"):
12381238
# collected and the memory freed. Creating it in this context manager
12391239
# guarantees that the copy will be around until the virtual file is
12401240
# closed. The conversion is implicit in dataarray_to_matrix.
1241-
matrix, region, inc = dataarray_to_matrix(grid)
1241+
matrix, region, inc = dataarray_to_matrix(grid, registration)
12421242
family = "GMT_IS_GRID|GMT_VIA_MATRIX"
12431243
geometry = "GMT_IS_SURFACE"
12441244
gmt_grid = self.create_data(

0 commit comments

Comments
 (0)