Skip to content

Commit 95fab98

Browse files
committed
pygmt.x2sys_cross: Add 'output_type' parameter for output in pandas/numpy/file formats
1 parent 62872d3 commit 95fab98

File tree

2 files changed

+39
-46
lines changed

2 files changed

+39
-46
lines changed

pygmt/src/x2sys_cross.py

+30-41
Original file line number | Diff line number | Diff line change
@@ -5,19 +5,19 @@
55
import contextlib
66
import os
77
from pathlib import Path
8+
from typing import Literal
89

910
import pandas as pd
10-
from packaging.version import Version
1111
from pygmt.clib import Session
1212
from pygmt.exceptions import GMTInvalidInput
1313
from pygmt.helpers import (
14-
GMTTempFile,
1514
build_arg_list,
1615
data_kind,
1716
fmt_docstring,
1817
kwargs_to_strings,
1918
unique_name,
2019
use_alias,
20+
validate_output_table_type,
2121
)
2222

2323

@@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
7171
Z="trackvalues",
7272
)
7373
@kwargs_to_strings(R="sequence")
74-
def x2sys_cross(tracks=None, outfile=None, **kwargs):
74+
def x2sys_cross(
75+
tracks=None,
76+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
77+
outfile: str | None = None,
78+
**kwargs,
79+
):
7580
r"""
7681
Calculate crossovers between track data files.
7782
@@ -102,11 +107,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
102107
set it will default to $GMT_SHAREDIR/x2sys]. (**Note**: MGD77 files
103108
will also be looked for via $MGD77_HOME/mgd77_paths.txt and .gmt
104109
files will be searched for via $GMT_SHAREDIR/mgg/gmtfile_paths).
105-
106-
outfile : str
107-
Optional. The file name for the output ASCII txt file to store the
108-
table in.
109-
110+
{output_type}
111+
{outfile}
110112
tag : str
111113
Specify the x2sys TAG which identifies the attributes of this data
112114
type.
@@ -183,15 +185,16 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
183185
184186
Returns
185187
-------
186-
crossover_errors : :class:`pandas.DataFrame` or None
187-
Table containing crossover error information.
188-
Return type depends on whether the ``outfile`` parameter is set:
189-
190-
- :class:`pandas.DataFrame` with (x, y, ..., etc) if ``outfile`` is not
191-
set
192-
- None if ``outfile`` is set (track output will be stored in the set in
193-
``outfile``)
188+
crossover_errors
189+
Table containing crossover error information. Return type depends on ``outfile``
190+
and ``output_type``:
191+
192+
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
193+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
194+
(depends on ``output_type``)
194195
"""
196+
output_type = validate_output_table_type(output_type, outfile=outfile)
197+
195198
with Session() as lib:
196199
file_contexts = []
197200
for track in tracks:
@@ -216,35 +219,21 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
216219
else:
217220
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
218221

219-
with GMTTempFile(suffix=".txt") as tmpfile:
222+
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
220223
with contextlib.ExitStack() as stack:
221224
fnames = [stack.enter_context(c) for c in file_contexts]
222-
if outfile is None:
223-
outfile = tmpfile.name
224225
lib.call_module(
225226
module="x2sys_cross",
226-
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
227-
)
228-
229-
# Read temporary csv output to a pandas table
230-
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
231-
# Read the tab-separated ASCII table
232-
date_format_kwarg = (
233-
{"date_format": "ISO8601"}
234-
if Version(pd.__version__) >= Version("2.0.0")
235-
else {}
227+
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
236228
)
237-
table = pd.read_csv(
238-
tmpfile.name,
239-
sep="\t",
240-
header=2, # Column names are on 2nd row
241-
comment=">", # Skip the 3rd row with a ">"
242-
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
243-
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
229+
result = lib.virtualfile_to_dataset(
230+
vfname=vouttbl, output_type=output_type, header=2
244231
)
245-
# Remove the "# " from "# x" in the first column
246-
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
247-
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
248-
table = None
249232

250-
return table
233+
# Convert 3rd and 4th columns to datetimes.
234+
# These two columns have names "t_1"/"t_2" or "i_1"/"i_2".
235+
# "t_1"/"t_2" means they are datetimes and should be converted.
236+
# "i_1"/"i_2" means they are dummy times (i.e., floating-point values).
237+
if output_type == "pandas" and result.columns[2] == "t_1":
238+
result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
239+
return result

pygmt/tests/test_x2sys_cross.py

+9-5
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,11 @@ def test_x2sys_cross_input_file_output_file():
4949
x2sys_init(tag=tag, fmtfile="xyz", force=True)
5050
outfile = tmpdir_p / "tmp_coe.txt"
5151
output = x2sys_cross(
52-
tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
52+
tracks=["@tut_ship.xyz"],
53+
tag=tag,
54+
coe="i",
55+
outfile=outfile,
56+
output_type="file",
5357
)
5458

5559
assert output is None # check that output is None since outfile is set
@@ -97,8 +101,8 @@ def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
97101
columns = list(output.columns)
98102
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
99103
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
100-
assert output.dtypes["i_1"].type == np.object_
101-
assert output.dtypes["i_2"].type == np.object_
104+
assert output.dtypes["i_1"].type == np.float64
105+
assert output.dtypes["i_2"].type == np.float64
102106

103107

104108
@pytest.mark.usefixtures("mock_x2sys_home")
@@ -158,8 +162,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
158162
columns = list(output.columns)
159163
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
160164
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
161-
assert output.dtypes["i_1"].type == np.object_
162-
assert output.dtypes["i_2"].type == np.object_
165+
assert output.dtypes["i_1"].type == np.float64
166+
assert output.dtypes["i_2"].type == np.float64
163167

164168

165169
@pytest.mark.usefixtures("mock_x2sys_home")

0 commit comments

Comments
 (0)