Skip to content

Commit 5280524

Browse files
committed
pygmt.x2sys_cross: Add 'output_type' parameter for output in pandas/numpy/file formats
1 parent 62872d3 commit 5280524

File tree

2 files changed

+30
-33
lines changed

2 files changed

+30
-33
lines changed

pygmt/src/x2sys_cross.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
import contextlib
66
import os
77
from pathlib import Path
8+
from typing import Literal
89

910
import pandas as pd
10-
from packaging.version import Version
1111
from pygmt.clib import Session
1212
from pygmt.exceptions import GMTInvalidInput
1313
from pygmt.helpers import (
14-
GMTTempFile,
1514
build_arg_list,
1615
data_kind,
1716
fmt_docstring,
1817
kwargs_to_strings,
1918
unique_name,
2019
use_alias,
20+
validate_output_table_type,
2121
)
2222

2323

@@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
7171
Z="trackvalues",
7272
)
7373
@kwargs_to_strings(R="sequence")
74-
def x2sys_cross(tracks=None, outfile=None, **kwargs):
74+
def x2sys_cross(
75+
tracks=None,
76+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
77+
outfile: str | None = None,
78+
**kwargs,
79+
):
7580
r"""
7681
Calculate crossovers between track data files.
7782
@@ -192,6 +197,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
192197
- None if ``outfile`` is set (track output will be stored in the file set by
193198
``outfile``)
194199
"""
200+
output_type = validate_output_table_type(output_type, outfile=outfile)
201+
195202
with Session() as lib:
196203
file_contexts = []
197204
for track in tracks:
@@ -216,35 +223,21 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
216223
else:
217224
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
218225

219-
with GMTTempFile(suffix=".txt") as tmpfile:
226+
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
220227
with contextlib.ExitStack() as stack:
221228
fnames = [stack.enter_context(c) for c in file_contexts]
222-
if outfile is None:
223-
outfile = tmpfile.name
224229
lib.call_module(
225230
module="x2sys_cross",
226-
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
227-
)
228-
229-
# Read temporary csv output to a pandas table
230-
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
231-
# Read the tab-separated ASCII table
232-
date_format_kwarg = (
233-
{"date_format": "ISO8601"}
234-
if Version(pd.__version__) >= Version("2.0.0")
235-
else {}
231+
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
236232
)
237-
table = pd.read_csv(
238-
tmpfile.name,
239-
sep="\t",
240-
header=2, # Column names are on 2nd row
241-
comment=">", # Skip the 3rd row with a ">"
242-
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
243-
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
233+
result = lib.virtualfile_to_dataset(
234+
vfname=vouttbl, output_type=output_type, header=2
244235
)
245-
# Remove the "# " from "# x" in the first column
246-
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
247-
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
248-
table = None
249236

250-
return table
237+
# Convert 3rd and 4th columns to datetimes.
238+
# These two columns have names "t_1"/"t_2" or "i_1"/"i_2".
239+
# "t_1"/"t_2" means they are datetimes and should be converted.
240+
# "i_1"/"i_2" means they are dummy times (i.e., floating-point values).
241+
if output_type != "file" and result.columns[2] == "t_1":
242+
result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
243+
return result

pygmt/tests/test_x2sys_cross.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ def test_x2sys_cross_input_file_output_file():
4949
x2sys_init(tag=tag, fmtfile="xyz", force=True)
5050
outfile = tmpdir_p / "tmp_coe.txt"
5151
output = x2sys_cross(
52-
tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
52+
tracks=["@tut_ship.xyz"],
53+
tag=tag,
54+
coe="i",
55+
outfile=outfile,
56+
output_type="file",
5357
)
5458

5559
assert output is None # check that output is None since outfile is set
@@ -97,8 +101,8 @@ def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
97101
columns = list(output.columns)
98102
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
99103
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
100-
assert output.dtypes["i_1"].type == np.object_
101-
assert output.dtypes["i_2"].type == np.object_
104+
assert output.dtypes["i_1"].type == np.float64
105+
assert output.dtypes["i_2"].type == np.float64
102106

103107

104108
@pytest.mark.usefixtures("mock_x2sys_home")
@@ -158,8 +162,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
158162
columns = list(output.columns)
159163
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
160164
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
161-
assert output.dtypes["i_1"].type == np.object_
162-
assert output.dtypes["i_2"].type == np.object_
165+
assert output.dtypes["i_1"].type == np.float64
166+
assert output.dtypes["i_2"].type == np.float64
163167

164168

165169
@pytest.mark.usefixtures("mock_x2sys_home")

0 commit comments

Comments
 (0)