Skip to content

Commit bcfab7f

Browse files
committed
Modify grdtrack parameter names according to code review
Rename 'table' parameter to 'points', remove default 'newcolname' parameter (require it to be explicitly set) and change the output 'ret' to 'track' in docstring. Unit tests updated accordingly.
1 parent 97c2b4a commit bcfab7f

File tree

2 files changed

+34
-29
lines changed

2 files changed

+34
-29
lines changed

pygmt/sampling.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717

1818
@fmt_docstring
19-
def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **kwargs):
19+
def grdtrack(
20+
points: pd.DataFrame, grid: xr.DataArray, newcolname: str = None, **kwargs
21+
):
2022
"""
2123
Sample grids at specified (x,y) locations.
2224
@@ -29,7 +31,7 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
2931
3032
Parameters
3133
----------
32-
table: pandas.DataFrame
34+
points: pandas.DataFrame
3335
Table with (x, y) or (lon, lat) values in the first two columns. More columns
3436
may be present.
3537
@@ -38,21 +40,26 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
3840
3941
newcolname: str
4042
Name for the new column in the table where the sampled values will be placed.
41-
Defaults to "z_".
4243
4344
Returns
4445
-------
45-
ret: pandas.DataFrame
46-
Table with (x, y, ..., z_) or (lon, lat, ..., z_) values.
46+
track: pandas.DataFrame
47+
Table with (x, y, ..., newcolname) or (lon, lat, ..., newcolname) values.
4748
4849
"""
50+
51+
try:
52+
assert isinstance(newcolname, str)
53+
except AssertionError:
54+
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
55+
4956
with GMTTempFile(suffix=".csv") as tmpfile:
5057
with Session() as lib:
51-
# Store the pandas.DataFrame table in virtualfile
52-
if data_kind(table) == "matrix":
53-
table_context = lib.virtualfile_from_matrix(table.values)
58+
# Store the pandas.DataFrame points table in virtualfile
59+
if data_kind(points) == "matrix":
60+
table_context = lib.virtualfile_from_matrix(points.values)
5461
else:
55-
raise GMTInvalidInput(f"Unrecognized data type {type(table)}")
62+
raise GMTInvalidInput(f"Unrecognized data type {type(points)}")
5663

5764
# Store the xarray.DataArray grid in virtualfile
5865
if data_kind(grid) == "grid":
@@ -62,7 +69,7 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
6269
else:
6370
raise GMTInvalidInput(f"Unrecognized data type {type(grid)}")
6471

65-
# Run grdtrack on the temporary (csv) table and (netcdf) grid virtualfiles
72+
# Run grdtrack on the temp (csv) points table and (netcdf) grid virtualfiles
6673
with table_context as csvfile:
6774
with grid_context as grdfile:
6875
kwargs.update({"G": grdfile})
@@ -72,7 +79,7 @@ def grdtrack(table: pd.DataFrame, grid: xr.DataArray, newcolname: str = "z_", **
7279
lib.call_module(module="grdtrack", args=arg_str)
7380

7481
# Read temporary csv output to a pandas table
75-
column_names = table.columns.to_list() + [newcolname]
82+
column_names = points.columns.to_list() + [newcolname]
7683
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
7784

7885
return result

pygmt/tests/test_grdtrack.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def test_grdtrack_input_dataframe_and_dataarray():
1919
dataframe = load_ocean_ridge_points()
2020
dataarray = load_east_pacific_rise_grid()
2121

22-
output = grdtrack(table=dataframe, grid=dataarray)
22+
output = grdtrack(points=dataframe, grid=dataarray, newcolname="bathymetry")
2323
assert isinstance(output, pd.DataFrame)
24-
assert output.columns.to_list() == ["longitude", "latitude", "z_"]
24+
assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"]
2525
assert output.iloc[0].to_list() == [-110.9536, -42.2489, -2950.49576833]
2626

2727
return output
@@ -34,27 +34,28 @@ def test_grdtrack_input_dataframe_and_ncfile():
3434
dataframe = load_ocean_ridge_points()
3535
ncfile = which("@spac_33.nc", download="c")
3636

37-
output = grdtrack(table=dataframe, grid=ncfile)
37+
output = grdtrack(points=dataframe, grid=ncfile, newcolname="bathymetry")
3838
assert isinstance(output, pd.DataFrame)
39-
assert output.columns.to_list() == ["longitude", "latitude", "z_"]
39+
assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"]
40+
assert output.iloc[0].to_list() == [-110.9536, -42.2489, -2950.49576833]
4041

4142
return output
4243

4344

44-
def test_grdtrack_input_wrong_kind_of_table():
45+
def test_grdtrack_wrong_kind_of_points_input():
4546
"""
46-
Run grdtrack using table input that is not a pandas.DataFrame (matrix)
47+
Run grdtrack using points input that is not a pandas.DataFrame (matrix)
4748
"""
4849
dataframe = load_ocean_ridge_points()
49-
invalid_table = dataframe.longitude.to_xarray()
50+
invalid_points = dataframe.longitude.to_xarray()
5051
dataarray = load_east_pacific_rise_grid()
5152

52-
assert data_kind(invalid_table) == "grid"
53+
assert data_kind(invalid_points) == "grid"
5354
with pytest.raises(GMTInvalidInput):
54-
grdtrack(table=invalid_table, grid=dataarray)
55+
grdtrack(points=invalid_points, grid=dataarray, newcolname="bathymetry")
5556

5657

57-
def test_grdtrack_input_wrong_kind_of_grid():
58+
def test_grdtrack_wrong_kind_of_grid_input():
5859
"""
5960
Run grdtrack using grid input that is not as xarray.DataArray (grid) or file
6061
"""
@@ -64,18 +65,15 @@ def test_grdtrack_input_wrong_kind_of_grid():
6465

6566
assert data_kind(invalid_grid) == "matrix"
6667
with pytest.raises(GMTInvalidInput):
67-
grdtrack(table=dataframe, grid=invalid_grid)
68+
grdtrack(points=dataframe, grid=invalid_grid, newcolname="bathymetry")
6869

6970

70-
def test_grdtrack_newcolname_setting():
71+
def test_grdtrack_without_newcolname_setting():
7172
"""
72-
Run grdtrack by passing in a non-default newcolname parameter setting
73+
Run grdtrack by not passing in newcolname parameter setting
7374
"""
7475
dataframe = load_ocean_ridge_points()
7576
dataarray = load_east_pacific_rise_grid()
7677

77-
output = grdtrack(table=dataframe, grid=dataarray, newcolname="bathymetry")
78-
assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"]
79-
assert output.iloc[0].to_list() == [-110.9536, -42.2489, -2950.49576833]
80-
81-
return output
78+
with pytest.raises(GMTInvalidInput):
79+
grdtrack(points=dataframe, grid=dataarray)

0 commit comments

Comments
 (0)