diff --git a/pygmt/src/x2sys_cross.py b/pygmt/src/x2sys_cross.py index eadd20dcfb2..af79cfee852 100644 --- a/pygmt/src/x2sys_cross.py +++ b/pygmt/src/x2sys_cross.py @@ -5,13 +5,12 @@ import contextlib import os from pathlib import Path +from typing import Any import pandas as pd -from packaging.version import Version from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_list, data_kind, fmt_docstring, @@ -71,7 +70,9 @@ def tempfile_from_dftrack(track, suffix): Z="trackvalues", ) @kwargs_to_strings(R="sequence") -def x2sys_cross(tracks=None, outfile=None, **kwargs): +def x2sys_cross( + tracks=None, outfile: str | None = None, **kwargs +) -> pd.DataFrame | None: r""" Calculate crossovers between track data files. @@ -103,10 +104,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs): will also be looked for via $MGD77_HOME/mgd77_paths.txt and .gmt files will be searched for via $GMT_SHAREDIR/mgg/gmtfile_paths). - outfile : str - Optional. The file name for the output ASCII txt file to store the - table in. - + outfile + The file name for the output ASCII txt file to store the table in. tag : str Specify the x2sys TAG which identifies the attributes of this data type. @@ -183,68 +182,74 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs): Returns ------- - crossover_errors : :class:`pandas.DataFrame` or None - Table containing crossover error information. - Return type depends on whether the ``outfile`` parameter is set: - - - :class:`pandas.DataFrame` with (x, y, ..., etc) if ``outfile`` is not - set - - None if ``outfile`` is set (track output will be stored in the set in - ``outfile``) + crossover_errors + Table containing crossover error information. A :class:`pandas.DataFrame` object + is returned if ``outfile`` is not set, otherwise ``None`` is returned and output + will be stored in file set by ``outfile``. """ - with Session() as lib: - file_contexts = [] - for track in tracks: - kind = data_kind(track) - if kind == "file": + # Determine output type based on 'outfile' parameter + output_type = "pandas" if outfile is None else "file" + + file_contexts: list[contextlib.AbstractContextManager[Any]] = [] + for track in tracks: + match data_kind(track): + case "file": file_contexts.append(contextlib.nullcontext(track)) - elif kind == "matrix": + case "matrix": # find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from # $X2SYS_HOME/TAGNAME/TAGNAME.tag file - lastline = ( - Path(os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag") - .read_text(encoding="utf8") - .strip() - .split("\n")[-1] - ) # e.g. "-Dxyz -Etsv -I1/1" + tagfile = Path( + os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag" + ) + # Last line is like "-Dxyz -Etsv -I1/1" + lastline = tagfile.read_text(encoding="utf8").splitlines()[-1] for item in sorted(lastline.split()): # sort list alphabetically if item.startswith(("-E", "-D")): # prefer -Etsv over -Dxyz suffix = item[2:] # e.g. tsv (1st choice) or xyz (2nd choice) # Save pandas.DataFrame track data to temporary file file_contexts.append(tempfile_from_dftrack(track=track, suffix=suffix)) - else: + case _: raise GMTInvalidInput(f"Unrecognized data type: {type(track)}") - with GMTTempFile(suffix=".txt") as tmpfile: + with Session() as lib: + with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl: with contextlib.ExitStack() as stack: fnames = [stack.enter_context(c) for c in file_contexts] - if outfile is None: - outfile = tmpfile.name lib.call_module( module="x2sys_cross", - args=build_arg_list(kwargs, infile=fnames, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame - # Read the tab-separated ASCII table - date_format_kwarg = ( - {"date_format": "ISO8601"} - if Version(pd.__version__) >= Version("2.0.0") - else {} + args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl), ) - table = pd.read_csv( - tmpfile.name, - sep="\t", - header=2, # Column names are on 2nd row - comment=">", # Skip the 3rd row with a ">" - parse_dates=[2, 3], # Datetimes on 3rd and 4th column - **date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2 + result = lib.virtualfile_to_dataset( + vfname=vouttbl, output_type=output_type, header=2 ) - # Remove the "# " from "# x" in the first column - table = table.rename(columns={table.columns[0]: table.columns[0][2:]}) - elif outfile != tmpfile.name: # if outfile is set, output in outfile only - table = None - return table + if output_type == "file": + return result + + # Convert 3rd and 4th columns to datetime/timedelta for pandas output. + # These two columns have names "t_1"/"t_2" or "i_1"/"i_2". + # "t_" means absolute datetimes and "i_" means dummy times. + # Internally, they are all represented as double-precision numbers in GMT, + # relative to TIME_EPOCH with the unit defined by TIME_UNIT. + # In GMT, TIME_UNIT can be 'y' (year), 'o' (month), 'w' (week), 'd' (day), + # 'h' (hour), 'm' (minute), 's' (second). Years are 365.2425 days and months + # are of equal length. + # pd.to_timedelta() supports unit of 'W'/'D'/'h'/'m'/'s'/'ms'/'us'/'ns'. + match time_unit := lib.get_default("TIME_UNIT"): + case "y": + unit = "s" + scale = 365.2425 * 86400.0 + case "o": + unit = "s" + scale = 365.2425 / 12.0 * 86400.0 + case "w" | "d" | "h" | "m" | "s": + unit = time_unit.upper() if time_unit in "wd" else time_unit + scale = 1.0 + + columns = result.columns[2:4] + result[columns] *= scale + result[columns] = result[columns].apply(pd.to_timedelta, unit=unit) + if columns[0][0] == "t": # "t" or "i": + result[columns] += pd.Timestamp(lib.get_default("TIME_EPOCH")) + return result diff --git a/pygmt/tests/test_x2sys_cross.py b/pygmt/tests/test_x2sys_cross.py index bae686efe27..09f424d1a42 100644 --- a/pygmt/tests/test_x2sys_cross.py +++ b/pygmt/tests/test_x2sys_cross.py @@ -12,7 +12,7 @@ import pandas as pd import pytest from packaging.version import Version -from pygmt import x2sys_cross, x2sys_init +from pygmt import config, x2sys_cross, x2sys_init from pygmt.clib import __gmt_version__ from pygmt.datasets import load_sample_data from pygmt.exceptions import GMTInvalidInput @@ -52,15 +52,20 @@ def test_x2sys_cross_input_file_output_file(): output = x2sys_cross( tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile ) - assert output is None # check that output is None since outfile is set assert outfile.stat().st_size > 0 # check that outfile exists at path - _ = pd.read_csv(outfile, sep="\t", header=2) # ensure ASCII text file loads ok + result = pd.read_csv(outfile, sep="\t", comment=">", header=2) + assert result.shape == (14374, 12) if sys.platform == "darwin" else (14338, 12) + columns = list(result.columns) + assert columns[:6] == ["# x", "y", "i_1", "i_2", "dist_1", "dist_2"] + assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"] + npt.assert_allclose(result["i_1"].min(), 45.2099, rtol=1.0e-4) + npt.assert_allclose(result["i_1"].max(), 82945.9370, rtol=1.0e-4) @pytest.mark.usefixtures("mock_x2sys_home") @pytest.mark.xfail( - condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin", + condition=Version(__gmt_version__) < Version("6.5.0"), reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188", ) def test_x2sys_cross_input_file_output_dataframe(): @@ -74,39 +79,70 @@ def test_x2sys_cross_input_file_output_dataframe(): output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, coe="i") assert isinstance(output, pd.DataFrame) - assert output.shape == (14338, 12) + assert output.shape == (14374, 12) if sys.platform == "darwin" else (14338, 12) columns = list(output.columns) assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"] assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"] + assert output["i_1"].dtype.type == np.timedelta64 + assert output["i_2"].dtype.type == np.timedelta64 + npt.assert_allclose(output["i_1"].min().total_seconds(), 45.2099, rtol=1.0e-4) + npt.assert_allclose(output["i_1"].max().total_seconds(), 82945.937, rtol=1.0e-4) @pytest.mark.benchmark @pytest.mark.usefixtures("mock_x2sys_home") -def test_x2sys_cross_input_dataframe_output_dataframe(tracks): +@pytest.mark.parametrize("unit", ["s", "o", "y"]) +def test_x2sys_cross_input_dataframe_output_dataframe(tracks, unit): """ Run x2sys_cross by passing in one dataframe, and output internal crossovers to a - pandas.DataFrame. + pandas.DataFrame, checking TIME_UNIT s (second), o (month), and y (year). """ with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir: tag = Path(tmpdir).name x2sys_init(tag=tag, fmtfile="xyz", force=True) - output = x2sys_cross(tracks=tracks, tag=tag, coe="i") + with config(TIME_UNIT=unit): + output = x2sys_cross(tracks=tracks, tag=tag, coe="i") assert isinstance(output, pd.DataFrame) assert output.shape == (14, 12) columns = list(output.columns) assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"] assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"] - assert output.dtypes["i_1"].type == np.object_ - assert output.dtypes["i_2"].type == np.object_ + assert output["i_1"].dtype.type == np.timedelta64 + assert output["i_2"].dtype.type == np.timedelta64 + + # Scale to convert a value to second + match unit: + case "y": + scale = 365.2425 * 86400.0 + case "o": + scale = 365.2425 / 12.0 * 86400.0 + case _: + scale = 1.0 + npt.assert_allclose( + output["i_1"].min().total_seconds(), 0.9175 * scale, rtol=1.0e-4 + ) + npt.assert_allclose( + output["i_1"].max().total_seconds(), 23.9996 * scale, rtol=1.0e-4 + ) @pytest.mark.usefixtures("mock_x2sys_home") -def test_x2sys_cross_input_two_dataframes(): +@pytest.mark.parametrize( + ("unit", "epoch"), + [ + ("s", "1970-01-01T00:00:00"), + ("o", "1970-01-01T00:00:00"), + ("y", "1970-01-01T00:00:00"), + ("s", "2012-03-04T05:06:07"), + ], +) +def test_x2sys_cross_input_two_dataframes(unit, epoch): """ Run x2sys_cross by passing in two pandas.DataFrame tables with a time column, and - output external crossovers to a pandas.DataFrame. + output external crossovers to a pandas.DataFrame, checking TIME_UNIT s (second), + o (month), and y (year), and TIME_EPOCH 1970 and 2012. """ with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir: tmpdir_p = Path(tmpdir) @@ -127,15 +163,22 @@ def test_x2sys_cross_input_two_dataframes(): track["time"] = pd.date_range(start=f"2020-{i}1-01", periods=10, freq="min") tracks.append(track) - output = x2sys_cross(tracks=tracks, tag=tag, coe="e") + with config(TIME_UNIT=unit, TIME_EPOCH=epoch): + output = x2sys_cross(tracks=tracks, tag=tag, coe="e") assert isinstance(output, pd.DataFrame) assert output.shape == (26, 12) columns = list(output.columns) assert columns[:6] == ["x", "y", "t_1", "t_2", "dist_1", "dist_2"] assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"] - assert output.dtypes["t_1"].type == np.datetime64 - assert output.dtypes["t_2"].type == np.datetime64 + assert output["t_1"].dtype.type == np.datetime64 + assert output["t_2"].dtype.type == np.datetime64 + + tolerance = pd.Timedelta("1ms") + t1_min = pd.Timestamp("2020-01-01 00:00:10.6677") + t1_max = pd.Timestamp("2020-01-01 00:08:29.8067") + assert abs(output["t_1"].min() - t1_min) < tolerance + assert abs(output["t_1"].max() - t1_max) < tolerance @pytest.mark.usefixtures("mock_x2sys_home") @@ -159,8 +202,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks): columns = list(output.columns) assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"] assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"] - assert output.dtypes["i_1"].type == np.object_ - assert output.dtypes["i_2"].type == np.object_ + assert output.dtypes["i_1"].type == np.timedelta64 + assert output.dtypes["i_2"].type == np.timedelta64 @pytest.mark.usefixtures("mock_x2sys_home") @@ -201,7 +244,7 @@ def test_x2sys_cross_invalid_tracks_input_type(tracks): @pytest.mark.usefixtures("mock_x2sys_home") @pytest.mark.xfail( - condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin", + condition=Version(__gmt_version__) < Version("6.5.0"), reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188", ) def test_x2sys_cross_region_interpolation_numpoints(): @@ -222,15 +265,21 @@ def test_x2sys_cross_region_interpolation_numpoints(): ) assert isinstance(output, pd.DataFrame) - assert output.shape == (3882, 12) - # Check crossover errors (z_X) and mean value of observables (z_M) - npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4) - npt.assert_allclose(output.z_M.mean(), -2896.875915) + if sys.platform == "darwin": + assert output.shape == (3894, 12) + # Check crossover errors (z_X) and mean value of observables (z_M) + npt.assert_allclose(output.z_X.mean(), -138.23215, rtol=1e-4) + npt.assert_allclose(output.z_M.mean(), -2897.187545, rtol=1e-4) + else: + assert output.shape == (3882, 12) + # Check crossover errors (z_X) and mean value of observables (z_M) + npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4) + npt.assert_allclose(output.z_M.mean(), -2896.875915, rtol=1e-4) @pytest.mark.usefixtures("mock_x2sys_home") @pytest.mark.xfail( - condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin", + condition=Version(__gmt_version__) < Version("6.5.0"), reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188", ) def test_x2sys_cross_trackvalues(): @@ -243,7 +292,12 @@ def test_x2sys_cross_trackvalues(): output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, trackvalues=True) assert isinstance(output, pd.DataFrame) - assert output.shape == (14338, 12) - # Check mean of track 1 values (z_1) and track 2 values (z_2) - npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4) - npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4) + if sys.platform == "darwin": + assert output.shape == (14374, 12) + # Check mean of track 1 values (z_1) and track 2 values (z_2) + npt.assert_allclose(output.z_1.mean(), -2422.973372, rtol=1e-4) + npt.assert_allclose(output.z_2.mean(), -2402.87476, rtol=1e-4) + else: + assert output.shape == (14338, 12) + npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4) + npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4)