Skip to content

Commit 53eb1b5

Browse files
authored
Allow pandas.DataFrame table and 1D/2D numpy array inputs into pygmt.info (GenericMappingTools#574)
Also renamed 'fname' argument to 'table' since `info` supports both file name inputs and pandas.DataFrame tables now.
1 parent 4856d64 commit 53eb1b5

File tree

2 files changed

+63
-20
lines changed

2 files changed

+63
-20
lines changed

pygmt/modules.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Non-plot GMT modules.
33
"""
4+
import numpy as np
45
import xarray as xr
56

67
from .clib import Session
@@ -55,7 +56,7 @@ def grdinfo(grid, **kwargs):
5556

5657
@fmt_docstring
5758
@use_alias(C="per_column", I="spacing", T="nearest_multiple")
58-
def info(fname, **kwargs):
59+
def info(table, **kwargs):
5960
"""
6061
Get information about data tables.
6162
@@ -74,8 +75,9 @@ def info(fname, **kwargs):
7475
7576
Parameters
7677
----------
77-
fname : str
78-
The file name of the input data table file.
78+
table : pandas.DataFrame or np.ndarray or str
79+
Either a pandas dataframe, a 1D/2D numpy.ndarray or a file name to an
80+
ASCII data table.
7981
per_column : bool
8082
Report the min/max values per column in separate columns.
8183
spacing : str
@@ -88,14 +90,25 @@ def info(fname, **kwargs):
8890
Report the min/max of the first (0'th) column to the nearest multiple
8991
of dz and output this as the string *-Tzmin/zmax/dz*.
9092
"""
91-
if not isinstance(fname, str):
92-
raise GMTInvalidInput("'info' only accepts file names.")
93+
kind = data_kind(table)
94+
with Session() as lib:
95+
if kind == "file":
96+
file_context = dummy_context(table)
97+
elif kind == "matrix":
98+
_table = np.asanyarray(table)
99+
if table.ndim == 1: # 1D arrays need to be 2D and transposed
100+
_table = np.transpose(np.atleast_2d(_table))
101+
file_context = lib.virtualfile_from_matrix(_table)
102+
else:
103+
raise GMTInvalidInput(f"Unrecognized data type: {type(table)}")
93104

94-
with GMTTempFile() as tmpfile:
95-
arg_str = " ".join([fname, build_arg_string(kwargs), "->" + tmpfile.name])
96-
with Session() as lib:
97-
lib.call_module("info", arg_str)
98-
return tmpfile.read()
105+
with GMTTempFile() as tmpfile:
106+
with file_context as fname:
107+
arg_str = " ".join(
108+
[fname, build_arg_string(kwargs), "->" + tmpfile.name]
109+
)
110+
lib.call_module("info", arg_str)
111+
return tmpfile.read()
99112

100113

101114
@fmt_docstring

pygmt/tests/test_info.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import os
55

66
import numpy as np
7+
import pandas as pd
78
import pytest
9+
import xarray as xr
810

911
from .. import info
1012
from ..exceptions import GMTInvalidInput
@@ -14,8 +16,8 @@
1416

1517

1618
def test_info():
17-
"Make sure info works"
18-
output = info(fname=POINTS_DATA)
19+
"Make sure info works on file name inputs"
20+
output = info(table=POINTS_DATA)
1921
expected_output = (
2022
f"{POINTS_DATA}: N = 20 "
2123
"<11.5309/61.7074> "
@@ -25,33 +27,61 @@ def test_info():
2527
assert output == expected_output
2628

2729

30+
def test_info_dataframe():
31+
"Make sure info works on pandas.DataFrame inputs"
32+
table = pd.read_csv(POINTS_DATA, sep=" ", header=None)
33+
output = info(table=table)
34+
expected_output = (
35+
"<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
36+
)
37+
assert output == expected_output
38+
39+
40+
def test_info_2d_array():
41+
"Make sure info works on 2D numpy.ndarray inputs"
42+
table = np.loadtxt(POINTS_DATA)
43+
output = info(table=table)
44+
expected_output = (
45+
"<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
46+
)
47+
assert output == expected_output
48+
49+
50+
def test_info_1d_array():
51+
"Make sure info works on 1D numpy.ndarray inputs"
52+
output = info(table=np.arange(20))
53+
expected_output = "<matrix memory>: N = 20 <0/19>\n"
54+
assert output == expected_output
55+
56+
2857
def test_info_per_column():
2958
"Make sure the per_column option works"
30-
output = info(fname=POINTS_DATA, per_column=True)
59+
output = info(table=POINTS_DATA, per_column=True)
3160
assert output == "11.5309 61.7074 -2.9289 7.8648 0.1412 0.9338\n"
3261

3362

3463
def test_info_spacing():
3564
"Make sure the spacing option works"
36-
output = info(fname=POINTS_DATA, spacing=0.1)
65+
output = info(table=POINTS_DATA, spacing=0.1)
3766
assert output == "-R11.5/61.8/-3/7.9\n"
3867

3968

4069
def test_info_per_column_spacing():
4170
"Make sure the per_column and spacing options work together"
42-
output = info(fname=POINTS_DATA, per_column=True, spacing=0.1)
71+
output = info(table=POINTS_DATA, per_column=True, spacing=0.1)
4372
assert output == "11.5 61.8 -3 7.9 0.1412 0.9338\n"
4473

4574

4675
def test_info_nearest_multiple():
4776
"Make sure the nearest_multiple option works"
48-
output = info(fname=POINTS_DATA, nearest_multiple=0.1)
77+
output = info(table=POINTS_DATA, nearest_multiple=0.1)
4978
assert output == "-T11.5/61.8/0.1\n"
5079

5180

5281
def test_info_fails():
53-
"Make sure info raises an exception if not given a file name"
54-
with pytest.raises(GMTInvalidInput):
55-
info(fname=21)
82+
"""
83+
Make sure info raises an exception if not given either a file name, pandas
84+
DataFrame, or numpy ndarray
85+
"""
5686
with pytest.raises(GMTInvalidInput):
57-
info(fname=np.arange(20))
87+
info(table=xr.DataArray(21))

0 commit comments

Comments
 (0)