diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index eaa1f0e13cc..366506a5922 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -8,7 +8,6 @@ GMTTempFile, build_arg_string, data_kind, - dummy_context, fmt_docstring, use_alias, ) @@ -51,8 +50,7 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): sampled values will be placed. outfile : str - Required if ``points`` is a file. The file name for the output ASCII - file. + The file name for the output ASCII file. {V} {f} @@ -68,21 +66,13 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): - None if ``outfile`` is set (track output will be stored in file set by ``outfile``) """ + if data_kind(points) == "matrix" and newcolname is None: + raise GMTInvalidInput("Please pass in a str to 'newcolname'") with GMTTempFile(suffix=".csv") as tmpfile: with Session() as lib: - # Store the pandas.DataFrame points table in virtualfile - if data_kind(points) == "matrix": - if newcolname is None: - raise GMTInvalidInput("Please pass in a str to 'newcolname'") - table_context = lib.virtualfile_from_matrix(points.values) - elif data_kind(points) == "file": - if outfile is None: - raise GMTInvalidInput("Please pass in a str to 'outfile'") - table_context = dummy_context(points) - else: - raise GMTInvalidInput(f"Unrecognized data type {type(points)}") - + # Choose how data will be passed into the module + table_context = lib.virtualfile_from_data(check_kind="vector", data=points) # Store the xarray.DataArray grid in virtualfile grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid) @@ -100,8 +90,11 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): # Read temporary csv output to a pandas table if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - column_names = points.columns.to_list() + [newcolname] - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + try: + column_names = points.columns.to_list() + [newcolname] + result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + except AttributeError: # 'str' object has no attribute 'columns' + result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") elif outfile != tmpfile.name: # return None if outfile set, output in outfile result = None diff --git a/pygmt/tests/test_grdtrack.py b/pygmt/tests/test_grdtrack.py index 8ed74adcd47..60c728c53cd 100644 --- a/pygmt/tests/test_grdtrack.py +++ b/pygmt/tests/test_grdtrack.py @@ -36,7 +36,7 @@ def test_grdtrack_input_dataframe_and_dataarray(dataarray): output = grdtrack(points=dataframe, grid=dataarray, newcolname="bathymetry") assert isinstance(output, pd.DataFrame) assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"] - npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2797.394987]) + npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2974.656296]) return output @@ -54,7 +54,7 @@ def test_grdtrack_input_csvfile_and_dataarray(dataarray): assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">") - npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2797.394987]) + npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2974.656296]) finally: os.remove(path=TEMP_TRACK) @@ -132,11 +132,14 @@ def test_grdtrack_without_newcolname_setting(dataarray): grdtrack(points=dataframe, grid=dataarray) -def test_grdtrack_without_outfile_setting(dataarray): +def test_grdtrack_without_outfile_setting(): """ Run grdtrack by not passing in outfile parameter setting. """ csvfile = which("@ridge.txt", download="c") + ncfile = which("@earth_relief_01d", download="a") - with pytest.raises(GMTInvalidInput): - grdtrack(points=csvfile, grid=dataarray) + output = grdtrack(points=csvfile, grid=ncfile) + npt.assert_allclose(output.iloc[0], [-32.2971, 37.4118, -1939.748245]) + + return output