Skip to content

Commit 97c2b4a

Browse files
committed
Enable netcdf file input for grdtrack
Enable passing in netcdf file as input into grdtrack, instead of just xarray.DataArray. Also fix bug with kwargs being overwritten instead of appended to...
1 parent a6cef02 commit 97c2b4a

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

pygmt/sampling.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
import xarray as xr
66

77
from .clib import Session
8-
from .helpers import build_arg_string, fmt_docstring, GMTTempFile, data_kind
8+
from .helpers import (
9+
build_arg_string,
10+
fmt_docstring,
11+
GMTTempFile,
12+
data_kind,
13+
dummy_context,
14+
)
915
from .exceptions import GMTInvalidInput
1016

1117

@@ -27,7 +33,7 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
2733
Table with (x, y) or (lon, lat) values in the first two columns. More columns
2834
may be present.
2935
30-
grid: xarray.DataArray
36+
grid: xarray.DataArray or file (netcdf)
3137
Gridded array from which to sample values from.
3238
3339
newcolname: str
@@ -51,13 +57,15 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
5157
# Store the xarray.DataArray grid in virtualfile
5258
if data_kind(grid) == "grid":
5359
grid_context = lib.virtualfile_from_grid(grid)
60+
elif data_kind(grid) == "file":
61+
grid_context = dummy_context(grid)
5462
else:
5563
raise GMTInvalidInput(f"Unrecognized data type {type(grid)}")
5664

5765
# Run grdtrack on the temporary (csv) table and (netcdf) grid virtualfiles
5866
with table_context as csvfile:
5967
with grid_context as grdfile:
60-
kwargs = {"G": grdfile}
68+
kwargs.update({"G": grdfile})
6169
arg_str = " ".join(
6270
[csvfile, build_arg_string(kwargs), "->" + tmpfile.name]
6371
)

pygmt/tests/test_grdtrack.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from .. import grdtrack
9+
from .. import which
910
from ..datasets import load_east_pacific_rise_grid, load_ocean_ridge_points
1011
from ..exceptions import GMTInvalidInput
1112
from ..helpers import data_kind
@@ -26,6 +27,20 @@ def test_grdtrack_input_dataframe_and_dataarray():
2627
return output
2728

2829

30+
def test_grdtrack_input_dataframe_and_ncfile():
31+
"""
32+
Run grdtrack by passing in a pandas.DataFrame and netcdf file as inputs
33+
"""
34+
dataframe = load_ocean_ridge_points()
35+
ncfile = which("@spac_33.nc", download="c")
36+
37+
output = grdtrack(table=dataframe, grid=ncfile)
38+
assert isinstance(output, pd.DataFrame)
39+
assert output.columns.to_list() == ["longitude", "latitude", "z_"]
40+
41+
return output
42+
43+
2944
def test_grdtrack_input_wrong_kind_of_table():
3045
"""
3146
Run grdtrack using table input that is not a pandas.DataFrame (matrix)
@@ -41,7 +56,7 @@ def test_grdtrack_input_wrong_kind_of_table():
4156

4257
def test_grdtrack_input_wrong_kind_of_grid():
4358
"""
44-
Run grdtrack using grid input that is not an xarray.DataArray (grid)
59+
Run grdtrack using grid input that is not as xarray.DataArray (grid) or file
4560
"""
4661
dataframe = load_ocean_ridge_points()
4762
dataarray = load_east_pacific_rise_grid()

0 commit comments

Comments
 (0)