Skip to content

Commit 6193938

Browse files
authored
Session.virtualfile_to_dataset: Add 'strings' output type for the array of trailing texts (#3157)
1 parent b490b0f commit 6193938

File tree

2 files changed

+48
-25
lines changed

2 files changed

+48
-25
lines changed

pygmt/clib/session.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,7 @@ def read_virtualfile(
17751775
def virtualfile_to_dataset(
17761776
self,
17771777
vfname: str,
1778-
output_type: Literal["pandas", "numpy", "file"] = "pandas",
1778+
output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
17791779
column_names: list[str] | None = None,
17801780
dtype: type | dict[str, type] | None = None,
17811781
index_col: str | int | None = None,
@@ -1796,6 +1796,7 @@ def virtualfile_to_dataset(
17961796
- ``"pandas"`` will return a :class:`pandas.DataFrame` object.
17971797
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
17981798
- ``"file"`` means the result was saved to a file and will return ``None``.
1799+
- ``"strings"`` will return the trailing text only as an array of strings.
17991800
column_names
18001801
The column names for the :class:`pandas.DataFrame` output.
18011802
dtype
@@ -1841,6 +1842,16 @@ def virtualfile_to_dataset(
18411842
... assert result is None
18421843
... assert Path(outtmp.name).stat().st_size > 0
18431844
...
1845+
... # strings output
1846+
... with Session() as lib:
1847+
... with lib.virtualfile_out(kind="dataset") as vouttbl:
1848+
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
1849+
... outstr = lib.virtualfile_to_dataset(
1850+
... vfname=vouttbl, output_type="strings"
1851+
... )
1852+
... assert isinstance(outstr, np.ndarray)
1853+
... assert outstr.dtype.kind in ("S", "U")
1854+
...
18441855
... # numpy output
18451856
... with Session() as lib:
18461857
... with lib.virtualfile_out(kind="dataset") as vouttbl:
@@ -1869,6 +1880,9 @@ def virtualfile_to_dataset(
18691880
... column_names=["col1", "col2", "col3", "coltext"],
18701881
... )
18711882
... assert isinstance(outpd2, pd.DataFrame)
1883+
>>> outstr
1884+
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
1885+
'TEXT123 TEXT456789'], dtype='<U18')
18721886
>>> outnp
18731887
array([[1.0, 2.0, 3.0, 'TEXT1 TEXT23'],
18741888
[4.0, 5.0, 6.0, 'TEXT4 TEXT567'],
@@ -1890,11 +1904,14 @@ def virtualfile_to_dataset(
18901904
if output_type == "file": # Already written to file, so return None
18911905
return None
18921906

1893-
# Read the virtual file as a GMT dataset and convert to pandas.DataFrame
1894-
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
1895-
column_names=column_names,
1896-
dtype=dtype,
1897-
index_col=index_col,
1907+
# Read the virtual file as a _GMT_DATASET object
1908+
result = self.read_virtualfile(vfname, kind="dataset").contents
1909+
1910+
if output_type == "strings": # strings output
1911+
return result.to_strings()
1912+
1913+
result = result.to_dataframe(
1914+
column_names=column_names, dtype=dtype, index_col=index_col
18981915
)
18991916
if output_type == "numpy": # numpy.ndarray output
19001917
return result.to_numpy()

pygmt/datatypes/dataset.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801
144144
("hidden", ctp.c_void_p),
145145
]
146146

147+
def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]:
148+
"""
149+
Convert the trailing text column to an array of strings.
150+
"""
151+
textvector = []
152+
for table in self.table[: self.n_tables]:
153+
for segment in table.contents.segment[: table.contents.n_segments]:
154+
if segment.contents.text:
155+
textvector.extend(segment.contents.text[: segment.contents.n_rows])
156+
return np.char.decode(textvector) if textvector else np.array([], dtype=str)
157+
147158
def to_dataframe(
148159
self,
149160
column_names: pd.Index | None = None,
@@ -194,7 +205,11 @@ def to_dataframe(
194205
... with lib.virtualfile_out(kind="dataset") as vouttbl:
195206
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
196207
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
208+
... text = ds.contents.to_strings()
197209
... df = ds.contents.to_dataframe()
210+
>>> text
211+
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
212+
'TEXT123 TEXT456789'], dtype='<U18')
198213
>>> df
199214
0 1 2 3
200215
0 1.0 2.0 3.0 TEXT1 TEXT23
@@ -207,28 +222,19 @@ def to_dataframe(
207222
vectors = []
208223
# Deal with numeric columns
209224
for icol in range(self.n_columns):
210-
colvector = []
211-
for itbl in range(self.n_tables):
212-
dtbl = self.table[itbl].contents
213-
for iseg in range(dtbl.n_segments):
214-
dseg = dtbl.segment[iseg].contents
215-
colvector.append(
216-
np.ctypeslib.as_array(dseg.data[icol], shape=(dseg.n_rows,))
217-
)
225+
colvector = [
226+
np.ctypeslib.as_array(
227+
seg.contents.data[icol], shape=(seg.contents.n_rows,)
228+
)
229+
for tbl in self.table[: self.n_tables]
230+
for seg in tbl.contents.segment[: tbl.contents.n_segments]
231+
]
218232
vectors.append(pd.Series(data=np.concatenate(colvector)))
219233

220234
# Deal with trailing text column
221-
textvector = []
222-
for itbl in range(self.n_tables):
223-
dtbl = self.table[itbl].contents
224-
for iseg in range(dtbl.n_segments):
225-
dseg = dtbl.segment[iseg].contents
226-
if dseg.text:
227-
textvector.extend(dseg.text[: dseg.n_rows])
228-
if textvector:
229-
vectors.append(
230-
pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
231-
)
235+
textvector = self.to_strings()
236+
if len(textvector) != 0:
237+
vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))
232238

233239
if len(vectors) == 0:
234240
# Return an empty DataFrame if no columns are found.

0 commit comments

Comments
 (0)