Skip to content

Commit 295afc0

Browse files
committed
pygmt.x2sys_cross: Add 'output_type' parameter for output in pandas/numpy/file formats
1 parent 62872d3 commit 295afc0

File tree

1 file changed

+18
-29
lines changed

1 file changed

+18
-29
lines changed

pygmt/src/x2sys_cross.py

+18-29
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
import contextlib
66
import os
77
from pathlib import Path
8+
from typing import Literal
89

910
import pandas as pd
10-
from packaging.version import Version
1111
from pygmt.clib import Session
1212
from pygmt.exceptions import GMTInvalidInput
1313
from pygmt.helpers import (
14-
GMTTempFile,
1514
build_arg_list,
1615
data_kind,
1716
fmt_docstring,
1817
kwargs_to_strings,
1918
unique_name,
2019
use_alias,
20+
validate_output_table_type,
2121
)
2222

2323

@@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
7171
Z="trackvalues",
7272
)
7373
@kwargs_to_strings(R="sequence")
74-
def x2sys_cross(tracks=None, outfile=None, **kwargs):
74+
def x2sys_cross(
75+
tracks=None,
76+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
77+
outfile: str | None = None,
78+
**kwargs,
79+
):
7580
r"""
7681
Calculate crossovers between track data files.
7782
@@ -192,6 +197,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
192197
- None if ``outfile`` is set (track output will be stored in the set in
193198
``outfile``)
194199
"""
200+
output_type = validate_output_table_type(output_type, outfile=outfile)
201+
195202
with Session() as lib:
196203
file_contexts = []
197204
for track in tracks:
@@ -216,35 +223,17 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
216223
else:
217224
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
218225

219-
with GMTTempFile(suffix=".txt") as tmpfile:
226+
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
220227
with contextlib.ExitStack() as stack:
221228
fnames = [stack.enter_context(c) for c in file_contexts]
222-
if outfile is None:
223-
outfile = tmpfile.name
224229
lib.call_module(
225230
module="x2sys_cross",
226-
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
227-
)
228-
229-
# Read temporary csv output to a pandas table
230-
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
231-
# Read the tab-separated ASCII table
232-
date_format_kwarg = (
233-
{"date_format": "ISO8601"}
234-
if Version(pd.__version__) >= Version("2.0.0")
235-
else {}
231+
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
236232
)
237-
table = pd.read_csv(
238-
tmpfile.name,
239-
sep="\t",
240-
header=2, # Column names are on 2nd row
241-
comment=">", # Skip the 3rd row with a ">"
242-
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
243-
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
233+
result = lib.virtualfile_to_dataset(
234+
vfname=vouttbl, output_type=output_type, header=2
244235
)
245-
# Remove the "# " from "# x" in the first column
246-
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
247-
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
248-
table = None
249-
250-
return table
236+
if output_type != "file":
237+
# Datetimes on 3rd and 4th columns
238+
result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
239+
return result

0 commit comments

Comments
 (0)