Skip to content

Commit a37fb0a

Browse files
committed
Enable ascii file input for grdtrack
Enable passing in ascii file inputs (csv, txt, etc) into grdtrack's 'points' parameter instead of just pandas.DataFrame. This requires a new 'outfile' parameter to be set. The type of 'points' input determines the type of 'track' returned, i.e. pd.DataFrame in, pd.DataFrame out; filename in, filename out. Extra unit tests created to test the various new input combinations and associated outputs.
1 parent 3776214 commit a37fb0a

File tree

3 files changed

+83
-20
lines changed

3 files changed

+83
-20
lines changed

pygmt/datasets/tutorial.py

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Functions to load sample data from the GMT tutorials.
33
"""
44
import pandas as pd
5-
import xarray as xr
65

76
from .. import which
87

pygmt/sampling.py

+27-18
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
GMT modules for Sampling of 1-D and 2-D Data
33
"""
44
import pandas as pd
5-
import xarray as xr
65

76
from .clib import Session
87
from .helpers import (
@@ -16,9 +15,7 @@
1615

1716

1817
@fmt_docstring
19-
def grdtrack(
20-
points: pd.DataFrame, grid: xr.DataArray, newcolname: str = None, **kwargs
21-
):
18+
def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
2219
"""
2320
Sample grids at specified (x,y) locations.
2421
@@ -33,33 +30,40 @@ def grdtrack(
3330
3431
Parameters
3532
----------
36-
points: pandas.DataFrame
37-
Table with (x, y) or (lon, lat) values in the first two columns. More columns
38-
may be present.
33+
points: pandas.DataFrame or file (csv, txt, etc)
34+
Either a table with (x, y) or (lon, lat) values in the first two columns,
35+
or a data file name. More columns may be present.
3936
4037
grid: xarray.DataArray or file (netcdf)
4138
Gridded array from which to sample values from.
4239
4340
newcolname: str
44-
Name for the new column in the table where the sampled values will be placed.
41+
Required if 'points' is a pandas.DataFrame. The name for the new column in the
42+
track pandas.DataFrame table where the sampled values will be placed.
43+
44+
outfile: str
45+
Required if 'points' is a file. The file name for the output ASCII file.
4546
4647
Returns
4748
-------
48-
track: pandas.DataFrame
49-
Table with (x, y, ..., newcolname) or (lon, lat, ..., newcolname) values.
49+
track: pandas.DataFrame or None
50+
Return type depends on whether the outfile parameter is set:
51+
- pandas.DataFrame table with (x, y, ..., newcolname) if outfile is not set
52+
- None if outfile is set (track output will be stored in outfile)
5053
5154
"""
5255

53-
try:
54-
assert isinstance(newcolname, str)
55-
except AssertionError:
56-
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
57-
5856
with GMTTempFile(suffix=".csv") as tmpfile:
5957
with Session() as lib:
6058
# Store the pandas.DataFrame points table in virtualfile
6159
if data_kind(points) == "matrix":
60+
if newcolname is None:
61+
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
6262
table_context = lib.virtualfile_from_matrix(points.values)
63+
elif data_kind(points) == "file":
64+
if outfile is None:
65+
raise GMTInvalidInput("Please pass in a str to 'outfile'")
66+
table_context = dummy_context(points)
6367
else:
6468
raise GMTInvalidInput(f"Unrecognized data type {type(points)}")
6569

@@ -75,13 +79,18 @@ def grdtrack(
7579
with table_context as csvfile:
7680
with grid_context as grdfile:
7781
kwargs.update({"G": grdfile})
82+
if outfile is None: # Output to tmpfile if outfile is not set
83+
outfile = tmpfile.name
7884
arg_str = " ".join(
79-
[csvfile, build_arg_string(kwargs), "->" + tmpfile.name]
85+
[csvfile, build_arg_string(kwargs), "->" + outfile]
8086
)
8187
lib.call_module(module="grdtrack", args=arg_str)
8288

8389
# Read temporary csv output to a pandas table
84-
column_names = points.columns.to_list() + [newcolname]
85-
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
90+
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
91+
column_names = points.columns.to_list() + [newcolname]
92+
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
93+
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
94+
result = None
8695

8796
return result

pygmt/tests/test_grdtrack.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Tests for grdtrack
33
"""
4+
import os
45

56
import pandas as pd
67
import pytest
@@ -11,6 +12,9 @@
1112
from ..exceptions import GMTInvalidInput
1213
from ..helpers import data_kind
1314

15+
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
16+
TEMP_TRACK = os.path.join(TEST_DATA_DIR, "tmp_track.txt")
17+
1418

1519
def test_grdtrack_input_dataframe_and_dataarray():
1620
"""
@@ -27,6 +31,26 @@ def test_grdtrack_input_dataframe_and_dataarray():
2731
return output
2832

2933

34+
def test_grdtrack_input_csvfile_and_dataarray():
35+
"""
36+
Run grdtrack by passing in a csvfile and xarray.DataArray as inputs
37+
"""
38+
csvfile = which("@ridge.txt", download="c")
39+
dataarray = load_earth_relief().sel(lat=slice(-49, -42), lon=slice(-118, -107))
40+
41+
try:
42+
output = grdtrack(points=csvfile, grid=dataarray, outfile=TEMP_TRACK)
43+
assert output is None # check that output is None since outfile is set
44+
assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path
45+
46+
track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">")
47+
assert track.iloc[0].to_list() == [-110.9536, -42.2489, -2823.96637605]
48+
finally:
49+
os.remove(path=TEMP_TRACK)
50+
51+
return output
52+
53+
3054
def test_grdtrack_input_dataframe_and_ncfile():
3155
"""
3256
Run grdtrack by passing in a pandas.DataFrame and netcdf file as inputs
@@ -42,9 +66,29 @@ def test_grdtrack_input_dataframe_and_ncfile():
4266
return output
4367

4468

69+
def test_grdtrack_input_csvfile_and_ncfile():
70+
"""
71+
Run grdtrack by passing in a csvfile and netcdf file as inputs
72+
"""
73+
csvfile = which("@ridge.txt", download="c")
74+
ncfile = which("@earth_relief_60m", download="c")
75+
76+
try:
77+
output = grdtrack(points=csvfile, grid=ncfile, outfile=TEMP_TRACK)
78+
assert output is None # check that output is None since outfile is set
79+
assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path
80+
81+
track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">")
82+
assert track.iloc[0].to_list() == [-32.2971, 37.4118, -1697.87197487]
83+
finally:
84+
os.remove(path=TEMP_TRACK)
85+
86+
return output
87+
88+
4589
def test_grdtrack_wrong_kind_of_points_input():
4690
"""
47-
Run grdtrack using points input that is not a pandas.DataFrame (matrix)
91+
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file
4892
"""
4993
dataframe = load_ocean_ridge_points()
5094
invalid_points = dataframe.longitude.to_xarray()
@@ -77,3 +121,14 @@ def test_grdtrack_without_newcolname_setting():
77121

78122
with pytest.raises(GMTInvalidInput):
79123
grdtrack(points=dataframe, grid=dataarray)
124+
125+
126+
def test_grdtrack_without_outfile_setting():
127+
"""
128+
Run grdtrack by not passing in outfile parameter setting
129+
"""
130+
csvfile = which("@ridge.txt", download="c")
131+
dataarray = load_earth_relief().sel(lat=slice(-49, -42), lon=slice(-118, -107))
132+
133+
with pytest.raises(GMTInvalidInput):
134+
grdtrack(points=csvfile, grid=dataarray)

0 commit comments

Comments
 (0)