diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 8a8b52df8e5..ba3644f0e28 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1810,6 +1810,7 @@ def virtualfile_to_dataset( self, vfname: str, output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas", + header: int | None = None, column_names: list[str] | None = None, dtype: type | dict[str, type] | None = None, index_col: str | int | None = None, @@ -1831,6 +1832,10 @@ def virtualfile_to_dataset( - ``"numpy"`` will return a :class:`numpy.ndarray` object. - ``"file"`` means the result was saved to a file and will return ``None``. - ``"strings"`` will return the trailing text only as an array of strings. + header + Row number containing column names for the :class:`pandas.DataFrame` output. + ``header=None`` means not to parse the column names from table header. + Ignored if the row number is larger than the number of headers in the table. column_names The column names for the :class:`pandas.DataFrame` output. dtype @@ -1945,7 +1950,7 @@ def virtualfile_to_dataset( return result.to_strings() result = result.to_dataframe( - column_names=column_names, dtype=dtype, index_col=index_col + header=header, column_names=column_names, dtype=dtype, index_col=index_col ) if output_type == "numpy": # numpy.ndarray output return result.to_numpy() diff --git a/pygmt/datatypes/dataset.py b/pygmt/datatypes/dataset.py index daf0073aefe..e5df4a2b4a0 100644 --- a/pygmt/datatypes/dataset.py +++ b/pygmt/datatypes/dataset.py @@ -27,6 +27,7 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801 >>> with GMTTempFile(suffix=".txt") as tmpfile: ... # Prepare the sample data file ... with Path(tmpfile.name).open(mode="w") as fp: + ... print("# x y z name", file=fp) ... print(">", file=fp) ... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) ... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) @@ -43,7 +44,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801 ... print(ds.min[: ds.n_columns], ds.max[: ds.n_columns]) ... # The table ... tbl = ds.table[0].contents - ... print(tbl.n_columns, tbl.n_segments, tbl.n_records) + ... print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers) + ... print(tbl.header[: tbl.n_headers]) ... print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns]) ... for i in range(tbl.n_segments): ... seg = tbl.segment[i].contents @@ -52,7 +54,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801 ... print(seg.text[: seg.n_rows]) 1 3 2 [1.0, 2.0, 3.0] [10.0, 11.0, 12.0] - 3 2 4 + 3 2 4 1 + [b'x y z name'] [1.0, 2.0, 3.0] [10.0, 11.0, 12.0] [1.0, 4.0] [2.0, 5.0] @@ -169,6 +172,7 @@ def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]: def to_dataframe( self, + header: int | None = None, column_names: pd.Index | None = None, dtype: type | Mapping[Any, type] | None = None, index_col: str | int | None = None, @@ -187,6 +191,10 @@ def to_dataframe( ---------- column_names A list of column names. + header + Row number containing column names. ``header=None`` means not to parse the + column names from table header. Ignored if the row number is larger than the + number of headers in the table. dtype Data type. Can be a single type for all columns or a dictionary mapping column names to types. @@ -207,6 +215,7 @@ def to_dataframe( >>> with GMTTempFile(suffix=".txt") as tmpfile: ... # prepare the sample data file ... with Path(tmpfile.name).open(mode="w") as fp: + ... print("# col1 col2 col3 colstr", file=fp) ... print(">", file=fp) ... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) ... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) @@ -218,12 +227,12 @@ def to_dataframe( ... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td") ... ds = lib.read_virtualfile(vouttbl, kind="dataset") ... text = ds.contents.to_strings() - ... df = ds.contents.to_dataframe() + ... df = ds.contents.to_dataframe(header=0) >>> text array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90', 'TEXT123 TEXT456789'], dtype='>> df - 0 1 2 3 + col1 col2 col3 colstr 0 1.0 2.0 3.0 TEXT1 TEXT23 1 4.0 5.0 6.0 TEXT4 TEXT567 2 7.0 8.0 9.0 TEXT8 TEXT90 @@ -248,6 +257,11 @@ def to_dataframe( if len(textvector) != 0: vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype())) + if header is not None: + tbl = self.table[0].contents # Use the first table! + if header < tbl.n_headers: + column_names = tbl.header[header].decode().split() + if len(vectors) == 0: # Return an empty DataFrame if no columns are found. df = pd.DataFrame(columns=column_names) @@ -255,7 +269,7 @@ def to_dataframe( # Create a DataFrame object by concatenating multiple columns df = pd.concat(objs=vectors, axis="columns") if column_names is not None: # Assign column names - df.columns = column_names + df.columns = column_names[: df.shape[1]] if dtype is not None: # Set dtype for the whole dataset or individual columns df = df.astype(dtype) if index_col is not None: # Use a specific column as index diff --git a/pygmt/tests/test_datatypes_dataset.py b/pygmt/tests/test_datatypes_dataset.py index 6481591b22a..9576595b6b2 100644 --- a/pygmt/tests/test_datatypes_dataset.py +++ b/pygmt/tests/test_datatypes_dataset.py @@ -40,14 +40,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=No return df -def dataframe_from_gmt(fname): +def dataframe_from_gmt(fname, **kwargs): """ Read tabular data as pandas.DataFrame using GMT virtual file. """ with Session() as lib: with lib.virtualfile_out(kind="dataset") as vouttbl: lib.call_module("read", f"{fname} {vouttbl} -Td") - df = lib.virtualfile_to_dataset(vfname=vouttbl) + df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs) return df @@ -84,6 +84,63 @@ def test_dataset_empty(): pd.testing.assert_frame_equal(df, expected_df) +def test_dataset_header(): + """ + Test parsing column names from dataset header. + """ + with GMTTempFile(suffix=".txt") as tmpfile: + with Path(tmpfile.name).open(mode="w") as fp: + print("# lon lat z text", file=fp) + print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) + print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) + + # Parse columne names from the first header line + df = dataframe_from_gmt(tmpfile.name, header=0) + assert df.columns.tolist() == ["lon", "lat", "z", "text"] + # pd.read_csv() can't parse the header line with a leading '#'. + # So, we need to skip the header line and manually set the column names. + expected_df = dataframe_from_pandas(tmpfile.name, header=None) + expected_df.columns = df.columns.tolist() + pd.testing.assert_frame_equal(df, expected_df) + + +def test_dataset_header_greater_than_nheaders(): + """ + Test passing a header line number that is greater than the number of header lines. + """ + with GMTTempFile(suffix=".txt") as tmpfile: + with Path(tmpfile.name).open(mode="w") as fp: + print("# lon lat z text", file=fp) + print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) + print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) + + # Parse column names from the second header line. + df = dataframe_from_gmt(tmpfile.name, header=1) + # There is only one header line, so the column names should be default. + assert df.columns.tolist() == [0, 1, 2, 3] + expected_df = dataframe_from_pandas(tmpfile.name, header=None) + pd.testing.assert_frame_equal(df, expected_df) + + +def test_dataset_header_too_many_names(): + """ + Test passing a header line with more column names than the number of columns. + """ + with GMTTempFile(suffix=".txt") as tmpfile: + with Path(tmpfile.name).open(mode="w") as fp: + print("# lon lat z text1 text2", file=fp) + print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) + print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) + + df = dataframe_from_gmt(tmpfile.name, header=0) + assert df.columns.tolist() == ["lon", "lat", "z", "text1"] + # pd.read_csv() can't parse the header line with a leading '#'. + # So, we need to skip the header line and manually set the column names. + expected_df = dataframe_from_pandas(tmpfile.name, header=None) + expected_df.columns = df.columns.tolist() + pd.testing.assert_frame_equal(df, expected_df) + + def test_dataset_to_strings_with_none_values(): """ Test that None values in the trailing text doesn't raise an exception.