Skip to content

Commit 19eb3aa

Browse files
committed
Get rid of temporary files from grdtrack
1 parent bd52024 commit 19eb3aa

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

pygmt/src/grdtrack.py

+30-32
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
"""
22
grdtrack - Sample grids at specified (x,y) locations.
33
"""
4+
import numpy as np
45
import pandas as pd
56
from pygmt.clib import Session
67
from pygmt.exceptions import GMTInvalidInput
7-
from pygmt.helpers import (
8-
GMTTempFile,
9-
build_arg_string,
10-
fmt_docstring,
11-
kwargs_to_strings,
12-
use_alias,
13-
)
8+
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias
149

1510
__doctest_skip__ = ["grdtrack"]
1611

@@ -43,7 +38,9 @@
4338
w="wrap",
4439
)
4540
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
46-
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
41+
def grdtrack(
42+
grid, points=None, newcolname=None, output_type="pandas", outfile=None, **kwargs
43+
):
4744
r"""
4845
Sample grids at specified (x,y) locations.
4946
@@ -292,29 +289,30 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
292289
if hasattr(points, "columns") and newcolname is None:
293290
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
294291

295-
with GMTTempFile(suffix=".csv") as tmpfile:
296-
with Session() as lib:
297-
with lib.virtualfile_from_data(
298-
check_kind="raster", data=grid
299-
) as grdfile, lib.virtualfile_from_data(
300-
check_kind="vector", data=points, required_data=False
301-
) as csvfile:
302-
kwargs["G"] = grdfile
303-
if outfile is None: # Output to tmpfile if outfile is not set
304-
outfile = tmpfile.name
305-
lib.call_module(
306-
module="grdtrack",
307-
args=build_arg_string(kwargs, infile=csvfile, outfile=outfile),
308-
)
292+
with Session() as lib:
293+
with lib.virtualfile_from_data(
294+
check_kind="raster", data=grid
295+
) as grdfile, lib.virtualfile_from_data(
296+
check_kind="vector", data=points, required_data=False
297+
) as csvfile, lib.virtualfile_to_gmtdataset() as outvfile:
298+
kwargs["G"] = grdfile
299+
lib.call_module(
300+
module="grdtrack",
301+
args=build_arg_string(kwargs, infile=csvfile, outfile=outvfile),
302+
)
303+
if outfile is not None:
304+
# if output_type == "file":
305+
lib.call_module("write", f"{outvfile} {outfile} -Td")
306+
return None
307+
308+
vectors = lib.gmtdataset_to_vectors(outvfile)
309+
310+
if output_type == "numpy":
311+
return np.array(vectors).T
309312

310-
# Read temporary csv output to a pandas table
311-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
312-
try:
313-
column_names = points.columns.to_list() + [newcolname]
314-
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
315-
except AttributeError: # 'str' object has no attribute 'columns'
316-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
317-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
318-
result = None
313+
if isinstance(points, pd.DataFrame):
314+
column_names = points.columns.to_list() + [newcolname]
315+
else:
316+
column_names = None
319317

320-
return result
318+
return pd.DataFrame(np.array(vectors).T, columns=column_names)

0 commit comments

Comments
 (0)