Skip to content

Commit 03a5cbd

Browse files
committed
pygmt.triangulate.delaunay_triples: Add 'output_type' parameter for output in pandas/numpy/file formats
1 parent 3a507a8 commit 03a5cbd

File tree

2 files changed

+22
-59
lines changed

2 files changed

+22
-59
lines changed

pygmt/src/triangulate.py

+22-37
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
Cartesian data.
44
"""
55

6+
from typing import Literal
7+
8+
import numpy as np
69
import pandas as pd
710
from pygmt.clib import Session
811
from pygmt.helpers import (
@@ -172,10 +175,10 @@ def delaunay_triples(
172175
y=None,
173176
z=None,
174177
*,
175-
output_type="pandas",
176-
outfile=None,
178+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
179+
outfile: str | None = None,
177180
**kwargs,
178-
):
181+
) -> pd.DataFrame | np.ndarray | None:
179182
"""
180183
Delaunay triangle based gridding of Cartesian data.
181184
@@ -204,16 +207,8 @@ def delaunay_triples(
204207
{table-classes}.
205208
{projection}
206209
{region}
207-
outfile : str or None
208-
The name of the output ASCII file to store the results of the
209-
histogram equalization in.
210-
output_type : str
211-
Determine the format the xyz data will be returned in [Default is
212-
``pandas``]:
213-
214-
- ``numpy`` - :class:`numpy.ndarray`
215-
- ``pandas``- :class:`pandas.DataFrame`
216-
- ``file`` - ASCII file (requires ``outfile``)
210+
{output_type}
211+
{outfile}
217212
{verbose}
218213
{binary}
219214
{nodata}
@@ -226,13 +221,13 @@ def delaunay_triples(
226221
227222
Returns
228223
-------
229-
ret : pandas.DataFrame or numpy.ndarray or None
224+
ret
230225
Return type depends on ``outfile`` and ``output_type``:
231226
232227
- None if ``outfile`` is set (output will be stored in file set by
233228
``outfile``)
234-
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if
235-
``outfile`` is not set (depends on ``output_type``)
229+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not
230+
set (depends on ``output_type``)
236231
237232
Note
238233
----
@@ -243,25 +238,15 @@ def delaunay_triples(
243238
"""
244239
output_type = validate_output_table_type(output_type, outfile)
245240

246-
with GMTTempFile(suffix=".txt") as tmpfile:
247-
with Session() as lib:
248-
with lib.virtualfile_in(
241+
with Session() as lib:
242+
with (
243+
lib.virtualfile_in(
249244
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
250-
) as vintbl:
251-
if outfile is None:
252-
outfile = tmpfile.name
253-
lib.call_module(
254-
module="triangulate",
255-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
256-
)
257-
258-
if outfile == tmpfile.name:
259-
# if user did not set outfile, return pd.DataFrame
260-
result = pd.read_csv(outfile, sep="\t", header=None)
261-
elif outfile != tmpfile.name:
262-
# return None if outfile set, output in outfile
263-
result = None
264-
265-
if output_type == "numpy":
266-
result = result.to_numpy()
267-
return result
245+
) as vintbl,
246+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
247+
):
248+
lib.call_module(
249+
module="triangulate",
250+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
251+
)
252+
return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl)

pygmt/tests/test_triangulate.py

-22
Original file line numberDiff line numberDiff line change
@@ -106,28 +106,6 @@ def test_delaunay_triples_ndarray_output(dataframe, expected_dataframe):
106106
np.testing.assert_allclose(actual=output, desired=expected_dataframe.to_numpy())
107107

108108

109-
def test_delaunay_triples_outfile(dataframe, expected_dataframe):
110-
"""
111-
Test triangulate.delaunay_triples with ``outfile``.
112-
"""
113-
with GMTTempFile(suffix=".txt") as tmpfile:
114-
with pytest.warns(RuntimeWarning) as record:
115-
result = triangulate.delaunay_triples(data=dataframe, outfile=tmpfile.name)
116-
assert len(record) == 1 # check that only one warning was raised
117-
assert result is None # return value is None
118-
assert Path(tmpfile.name).stat().st_size > 0
119-
temp_df = pd.read_csv(filepath_or_buffer=tmpfile.name, sep="\t", header=None)
120-
pd.testing.assert_frame_equal(left=temp_df, right=expected_dataframe)
121-
122-
123-
def test_delaunay_triples_invalid_format(dataframe):
124-
"""
125-
Test that triangulate.delaunay_triples fails with incorrect format.
126-
"""
127-
with pytest.raises(GMTInvalidInput):
128-
triangulate.delaunay_triples(data=dataframe, output_type=1)
129-
130-
131109
@pytest.mark.benchmark
132110
def test_regular_grid_no_outgrid(dataframe, expected_grid):
133111
"""

0 commit comments

Comments
 (0)