Skip to content

Commit 844594f

Browse files
seisman and weiji14 authored
pygmt.x2sys_cross: Refactor to use virtualfiles for output tables
Co-authored-by: Wei Ji <[email protected]>
1 parent 88eddc7 commit 844594f

File tree

2 files changed

+139
-80
lines changed

2 files changed

+139
-80
lines changed

pygmt/src/x2sys_cross.py

+58-53
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import contextlib
66
import os
77
from pathlib import Path
8+
from typing import Any
89

910
import pandas as pd
10-
from packaging.version import Version
1111
from pygmt.clib import Session
1212
from pygmt.exceptions import GMTInvalidInput
1313
from pygmt.helpers import (
14-
GMTTempFile,
1514
build_arg_list,
1615
data_kind,
1716
fmt_docstring,
@@ -71,7 +70,9 @@ def tempfile_from_dftrack(track, suffix):
7170
Z="trackvalues",
7271
)
7372
@kwargs_to_strings(R="sequence")
74-
def x2sys_cross(tracks=None, outfile=None, **kwargs):
73+
def x2sys_cross(
74+
tracks=None, outfile: str | None = None, **kwargs
75+
) -> pd.DataFrame | None:
7576
r"""
7677
Calculate crossovers between track data files.
7778
@@ -103,10 +104,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
103104
will also be looked for via $MGD77_HOME/mgd77_paths.txt and .gmt
104105
files will be searched for via $GMT_SHAREDIR/mgg/gmtfile_paths).
105106
106-
outfile : str
107-
Optional. The file name for the output ASCII txt file to store the
108-
table in.
109-
107+
outfile
108+
The file name for the output ASCII txt file to store the table in.
110109
tag : str
111110
Specify the x2sys TAG which identifies the attributes of this data
112111
type.
@@ -183,68 +182,74 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
183182
184183
Returns
185184
-------
186-
crossover_errors : :class:`pandas.DataFrame` or None
187-
Table containing crossover error information.
188-
Return type depends on whether the ``outfile`` parameter is set:
189-
190-
- :class:`pandas.DataFrame` with (x, y, ..., etc) if ``outfile`` is not
191-
set
192-
- None if ``outfile`` is set (track output will be stored in the set in
193-
``outfile``)
185+
crossover_errors
186+
Table containing crossover error information. A :class:`pandas.DataFrame` object
187+
is returned if ``outfile`` is not set, otherwise ``None`` is returned and output
188+
will be stored in file set by ``outfile``.
194189
"""
195-
with Session() as lib:
196-
file_contexts = []
197-
for track in tracks:
198-
kind = data_kind(track)
199-
if kind == "file":
190+
# Determine output type based on 'outfile' parameter
191+
output_type = "pandas" if outfile is None else "file"
192+
193+
file_contexts: list[contextlib.AbstractContextManager[Any]] = []
194+
for track in tracks:
195+
match data_kind(track):
196+
case "file":
200197
file_contexts.append(contextlib.nullcontext(track))
201-
elif kind == "matrix":
198+
case "matrix":
202199
# find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
203200
# $X2SYS_HOME/TAGNAME/TAGNAME.tag file
204-
lastline = (
205-
Path(os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag")
206-
.read_text(encoding="utf8")
207-
.strip()
208-
.split("\n")[-1]
209-
) # e.g. "-Dxyz -Etsv -I1/1"
201+
tagfile = Path(
202+
os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag"
203+
)
204+
# Last line is like "-Dxyz -Etsv -I1/1"
205+
lastline = tagfile.read_text(encoding="utf8").splitlines()[-1]
210206
for item in sorted(lastline.split()): # sort list alphabetically
211207
if item.startswith(("-E", "-D")): # prefer -Etsv over -Dxyz
212208
suffix = item[2:] # e.g. tsv (1st choice) or xyz (2nd choice)
213209

214210
# Save pandas.DataFrame track data to temporary file
215211
file_contexts.append(tempfile_from_dftrack(track=track, suffix=suffix))
216-
else:
212+
case _:
217213
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
218214

219-
with GMTTempFile(suffix=".txt") as tmpfile:
215+
with Session() as lib:
216+
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
220217
with contextlib.ExitStack() as stack:
221218
fnames = [stack.enter_context(c) for c in file_contexts]
222-
if outfile is None:
223-
outfile = tmpfile.name
224219
lib.call_module(
225220
module="x2sys_cross",
226-
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
227-
)
228-
229-
# Read temporary csv output to a pandas table
230-
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
231-
# Read the tab-separated ASCII table
232-
date_format_kwarg = (
233-
{"date_format": "ISO8601"}
234-
if Version(pd.__version__) >= Version("2.0.0")
235-
else {}
221+
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
236222
)
237-
table = pd.read_csv(
238-
tmpfile.name,
239-
sep="\t",
240-
header=2, # Column names are on 2nd row
241-
comment=">", # Skip the 3rd row with a ">"
242-
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
243-
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
223+
result = lib.virtualfile_to_dataset(
224+
vfname=vouttbl, output_type=output_type, header=2
244225
)
245-
# Remove the "# " from "# x" in the first column
246-
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
247-
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
248-
table = None
249226

250-
return table
227+
if output_type == "file":
228+
return result
229+
230+
# Convert 3rd and 4th columns to datetime/timedelta for pandas output.
231+
# These two columns have names "t_1"/"t_2" or "i_1"/"i_2".
232+
# "t_" means absolute datetimes and "i_" means dummy times.
233+
# Internally, they are all represented as double-precision numbers in GMT,
234+
# relative to TIME_EPOCH with the unit defined by TIME_UNIT.
235+
# In GMT, TIME_UNIT can be 'y' (year), 'o' (month), 'w' (week), 'd' (day),
236+
# 'h' (hour), 'm' (minute), 's' (second). Years are 365.2425 days and months
237+
# are of equal length.
238+
# pd.to_timedelta() supports unit of 'W'/'D'/'h'/'m'/'s'/'ms'/'us'/'ns'.
239+
match time_unit := lib.get_default("TIME_UNIT"):
240+
case "y":
241+
unit = "s"
242+
scale = 365.2425 * 86400.0
243+
case "o":
244+
unit = "s"
245+
scale = 365.2425 / 12.0 * 86400.0
246+
case "w" | "d" | "h" | "m" | "s":
247+
unit = time_unit.upper() if time_unit in "wd" else time_unit
248+
scale = 1.0
249+
250+
columns = result.columns[2:4]
251+
result[columns] *= scale
252+
result[columns] = result[columns].apply(pd.to_timedelta, unit=unit)
253+
if columns[0][0] == "t": # "t" or "i":
254+
result[columns] += pd.Timestamp(lib.get_default("TIME_EPOCH"))
255+
return result

pygmt/tests/test_x2sys_cross.py

+81-27
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pandas as pd
1313
import pytest
1414
from packaging.version import Version
15-
from pygmt import x2sys_cross, x2sys_init
15+
from pygmt import config, x2sys_cross, x2sys_init
1616
from pygmt.clib import __gmt_version__
1717
from pygmt.datasets import load_sample_data
1818
from pygmt.exceptions import GMTInvalidInput
@@ -52,15 +52,20 @@ def test_x2sys_cross_input_file_output_file():
5252
output = x2sys_cross(
5353
tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
5454
)
55-
5655
assert output is None # check that output is None since outfile is set
5756
assert outfile.stat().st_size > 0 # check that outfile exists at path
58-
_ = pd.read_csv(outfile, sep="\t", header=2) # ensure ASCII text file loads ok
57+
result = pd.read_csv(outfile, sep="\t", comment=">", header=2)
58+
# Parenthesize the platform-dependent expected shape: without the parentheses the
# conditional expression binds looser than ``==``, so on non-macOS platforms the
# statement would assert the (always truthy) tuple instead of comparing shapes.
assert result.shape == ((14374, 12) if sys.platform == "darwin" else (14338, 12))
59+
columns = list(result.columns)
60+
assert columns[:6] == ["# x", "y", "i_1", "i_2", "dist_1", "dist_2"]
61+
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
62+
npt.assert_allclose(result["i_1"].min(), 45.2099, rtol=1.0e-4)
63+
npt.assert_allclose(result["i_1"].max(), 82945.9370, rtol=1.0e-4)
5964

6065

6166
@pytest.mark.usefixtures("mock_x2sys_home")
6267
@pytest.mark.xfail(
63-
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
68+
condition=Version(__gmt_version__) < Version("6.5.0"),
6469
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
6570
)
6671
def test_x2sys_cross_input_file_output_dataframe():
@@ -74,39 +79,70 @@ def test_x2sys_cross_input_file_output_dataframe():
7479
output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, coe="i")
7580

7681
assert isinstance(output, pd.DataFrame)
77-
assert output.shape == (14338, 12)
82+
# Parenthesize the platform-dependent expected shape: without the parentheses the
# conditional expression binds looser than ``==``, so on non-macOS platforms the
# statement would assert the (always truthy) tuple instead of comparing shapes.
assert output.shape == ((14374, 12) if sys.platform == "darwin" else (14338, 12))
7883
columns = list(output.columns)
7984
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
8085
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
86+
assert output["i_1"].dtype.type == np.timedelta64
87+
assert output["i_2"].dtype.type == np.timedelta64
88+
npt.assert_allclose(output["i_1"].min().total_seconds(), 45.2099, rtol=1.0e-4)
89+
npt.assert_allclose(output["i_1"].max().total_seconds(), 82945.937, rtol=1.0e-4)
8190

8291

8392
@pytest.mark.benchmark
8493
@pytest.mark.usefixtures("mock_x2sys_home")
85-
def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
94+
@pytest.mark.parametrize("unit", ["s", "o", "y"])
95+
def test_x2sys_cross_input_dataframe_output_dataframe(tracks, unit):
8696
"""
8797
Run x2sys_cross by passing in one dataframe, and output internal crossovers to a
88-
pandas.DataFrame.
98+
pandas.DataFrame, checking TIME_UNIT s (second), o (month), and y (year).
8999
"""
90100
with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir:
91101
tag = Path(tmpdir).name
92102
x2sys_init(tag=tag, fmtfile="xyz", force=True)
93103

94-
output = x2sys_cross(tracks=tracks, tag=tag, coe="i")
104+
with config(TIME_UNIT=unit):
105+
output = x2sys_cross(tracks=tracks, tag=tag, coe="i")
95106

96107
assert isinstance(output, pd.DataFrame)
97108
assert output.shape == (14, 12)
98109
columns = list(output.columns)
99110
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
100111
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
101-
assert output.dtypes["i_1"].type == np.object_
102-
assert output.dtypes["i_2"].type == np.object_
112+
assert output["i_1"].dtype.type == np.timedelta64
113+
assert output["i_2"].dtype.type == np.timedelta64
114+
115+
# Scale to convert a value to second
116+
match unit:
117+
case "y":
118+
scale = 365.2425 * 86400.0
119+
case "o":
120+
scale = 365.2425 / 12.0 * 86400.0
121+
case _:
122+
scale = 1.0
123+
npt.assert_allclose(
124+
output["i_1"].min().total_seconds(), 0.9175 * scale, rtol=1.0e-4
125+
)
126+
npt.assert_allclose(
127+
output["i_1"].max().total_seconds(), 23.9996 * scale, rtol=1.0e-4
128+
)
103129

104130

105131
@pytest.mark.usefixtures("mock_x2sys_home")
106-
def test_x2sys_cross_input_two_dataframes():
132+
@pytest.mark.parametrize(
133+
("unit", "epoch"),
134+
[
135+
("s", "1970-01-01T00:00:00"),
136+
("o", "1970-01-01T00:00:00"),
137+
("y", "1970-01-01T00:00:00"),
138+
("s", "2012-03-04T05:06:07"),
139+
],
140+
)
141+
def test_x2sys_cross_input_two_dataframes(unit, epoch):
107142
"""
108143
Run x2sys_cross by passing in two pandas.DataFrame tables with a time column, and
109-
output external crossovers to a pandas.DataFrame.
144+
output external crossovers to a pandas.DataFrame, checking TIME_UNIT s (second),
145+
o (month), and y (year), and TIME_EPOCH 1970 and 2012.
110146
"""
111147
with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir:
112148
tmpdir_p = Path(tmpdir)
@@ -127,15 +163,22 @@ def test_x2sys_cross_input_two_dataframes():
127163
track["time"] = pd.date_range(start=f"2020-{i}1-01", periods=10, freq="min")
128164
tracks.append(track)
129165

130-
output = x2sys_cross(tracks=tracks, tag=tag, coe="e")
166+
with config(TIME_UNIT=unit, TIME_EPOCH=epoch):
167+
output = x2sys_cross(tracks=tracks, tag=tag, coe="e")
131168

132169
assert isinstance(output, pd.DataFrame)
133170
assert output.shape == (26, 12)
134171
columns = list(output.columns)
135172
assert columns[:6] == ["x", "y", "t_1", "t_2", "dist_1", "dist_2"]
136173
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
137-
assert output.dtypes["t_1"].type == np.datetime64
138-
assert output.dtypes["t_2"].type == np.datetime64
174+
assert output["t_1"].dtype.type == np.datetime64
175+
assert output["t_2"].dtype.type == np.datetime64
176+
177+
tolerance = pd.Timedelta("1ms")
178+
t1_min = pd.Timestamp("2020-01-01 00:00:10.6677")
179+
t1_max = pd.Timestamp("2020-01-01 00:08:29.8067")
180+
assert abs(output["t_1"].min() - t1_min) < tolerance
181+
assert abs(output["t_1"].max() - t1_max) < tolerance
139182

140183

141184
@pytest.mark.usefixtures("mock_x2sys_home")
@@ -159,8 +202,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
159202
columns = list(output.columns)
160203
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
161204
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
162-
assert output.dtypes["i_1"].type == np.object_
163-
assert output.dtypes["i_2"].type == np.object_
205+
assert output.dtypes["i_1"].type == np.timedelta64
206+
assert output.dtypes["i_2"].type == np.timedelta64
164207

165208

166209
@pytest.mark.usefixtures("mock_x2sys_home")
@@ -201,7 +244,7 @@ def test_x2sys_cross_invalid_tracks_input_type(tracks):
201244

202245
@pytest.mark.usefixtures("mock_x2sys_home")
203246
@pytest.mark.xfail(
204-
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
247+
condition=Version(__gmt_version__) < Version("6.5.0"),
205248
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
206249
)
207250
def test_x2sys_cross_region_interpolation_numpoints():
@@ -222,15 +265,21 @@ def test_x2sys_cross_region_interpolation_numpoints():
222265
)
223266

224267
assert isinstance(output, pd.DataFrame)
225-
assert output.shape == (3882, 12)
226-
# Check crossover errors (z_X) and mean value of observables (z_M)
227-
npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4)
228-
npt.assert_allclose(output.z_M.mean(), -2896.875915)
268+
if sys.platform == "darwin":
269+
assert output.shape == (3894, 12)
270+
# Check crossover errors (z_X) and mean value of observables (z_M)
271+
npt.assert_allclose(output.z_X.mean(), -138.23215, rtol=1e-4)
272+
npt.assert_allclose(output.z_M.mean(), -2897.187545, rtol=1e-4)
273+
else:
274+
assert output.shape == (3882, 12)
275+
# Check crossover errors (z_X) and mean value of observables (z_M)
276+
npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4)
277+
npt.assert_allclose(output.z_M.mean(), -2896.875915, rtol=1e-4)
229278

230279

231280
@pytest.mark.usefixtures("mock_x2sys_home")
232281
@pytest.mark.xfail(
233-
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
282+
condition=Version(__gmt_version__) < Version("6.5.0"),
234283
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
235284
)
236285
def test_x2sys_cross_trackvalues():
@@ -243,7 +292,12 @@ def test_x2sys_cross_trackvalues():
243292
output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, trackvalues=True)
244293

245294
assert isinstance(output, pd.DataFrame)
246-
assert output.shape == (14338, 12)
247-
# Check mean of track 1 values (z_1) and track 2 values (z_2)
248-
npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4)
249-
npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4)
295+
if sys.platform == "darwin":
296+
assert output.shape == (14374, 12)
297+
# Check mean of track 1 values (z_1) and track 2 values (z_2)
298+
npt.assert_allclose(output.z_1.mean(), -2422.973372, rtol=1e-4)
299+
npt.assert_allclose(output.z_2.mean(), -2402.87476, rtol=1e-4)
300+
else:
301+
assert output.shape == (14338, 12)
302+
npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4)
303+
npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4)

0 commit comments

Comments
 (0)