Skip to content

Refactor grdtrack to use virtualfile_from_data and improve i/o to pandas.DataFrame #1189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
GMTTempFile,
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
use_alias,
)
Expand Down Expand Up @@ -51,8 +50,7 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
sampled values will be placed.

outfile : str
Required if ``points`` is a file. The file name for the output ASCII
file.
The file name for the output ASCII file.

{V}
{f}
Expand All @@ -68,21 +66,13 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
- None if ``outfile`` is set (track output will be stored in file set
by ``outfile``)
"""
if data_kind(points) == "matrix" and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
# Store the pandas.DataFrame points table in virtualfile
if data_kind(points) == "matrix":
if newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
table_context = lib.virtualfile_from_matrix(points.values)
elif data_kind(points) == "file":
if outfile is None:
raise GMTInvalidInput("Please pass in a str to 'outfile'")
table_context = dummy_context(points)
else:
raise GMTInvalidInput(f"Unrecognized data type {type(points)}")

# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(check_kind="vector", data=points)
# Store the xarray.DataArray grid in virtualfile
grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

Expand All @@ -100,8 +90,11 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
column_names = points.columns.to_list() + [newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
try:
column_names = points.columns.to_list() + [newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
except AttributeError: # 'str' object has no attribute 'columns'
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

Expand Down
13 changes: 8 additions & 5 deletions pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_grdtrack_input_dataframe_and_dataarray(dataarray):
output = grdtrack(points=dataframe, grid=dataarray, newcolname="bathymetry")
assert isinstance(output, pd.DataFrame)
assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"]
npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2797.394987])
npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2974.656296])

return output

Expand All @@ -54,7 +54,7 @@ def test_grdtrack_input_csvfile_and_dataarray(dataarray):
assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path

track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">")
npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2797.394987])
npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2974.656296])
finally:
os.remove(path=TEMP_TRACK)

Expand Down Expand Up @@ -132,11 +132,14 @@ def test_grdtrack_without_newcolname_setting(dataarray):
grdtrack(points=dataframe, grid=dataarray)


def test_grdtrack_without_outfile_setting(dataarray):
def test_grdtrack_without_outfile_setting():
"""
Run grdtrack by not passing in outfile parameter setting.
"""
csvfile = which("@ridge.txt", download="c")
ncfile = which("@earth_relief_01d", download="a")

with pytest.raises(GMTInvalidInput):
grdtrack(points=csvfile, grid=dataarray)
output = grdtrack(points=csvfile, grid=ncfile)
npt.assert_allclose(output.iloc[0], [-32.2971, 37.4118, -1939.748245])

return output