Skip to content

Commit c2684ba

Browse files
weiji14seisman
andauthored
Refactor plot and plot3d to use virtualfile_from_data (#990)
Added an additional `extra_arrays` parameter to the `virtualfile_from_data` function to accept optional numpy arrays from the plot and plot3d functions. * Just use virtualfile_from_matrix for non-datetime 2D numpy arrays More efficient to pass in whole 2D numpy array matrix as a virtualfile to GMT, and this fixes the segmentation fault crash on test_plot3d_matrix_color when the data was passed in via virtualfile_from_vectors instead. * Add docstring on extra_arrays parameter in virtualfile_from_data * Use virtualfile_from_matrix on int/uint/float types and add a test Co-authored-by: Dongdong Tian <[email protected]>
1 parent 5385fa5 commit c2684ba

File tree

4 files changed

+37
-25
lines changed

4 files changed

+37
-25
lines changed

pygmt/clib/session.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -1360,7 +1360,9 @@ def virtualfile_from_grid(self, grid):
13601360
with self.open_virtual_file(*args) as vfile:
13611361
yield vfile
13621362

1363-
def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=None):
1363+
def virtualfile_from_data(
1364+
self, check_kind=None, data=None, x=None, y=None, z=None, extra_arrays=None
1365+
):
13641366
"""
13651367
Store any data inside a virtual file.
13661368
@@ -1378,6 +1380,9 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
13781380
raster grid, a vector matrix/arrays, or other supported data input.
13791381
x/y/z : 1d arrays or None
13801382
x, y and z columns as numpy arrays.
1383+
extra_arrays : list of 1d arrays
1384+
Optional. A list of numpy arrays in addition to x, y and z. All
1385+
of these arrays must be of the same size as the x/y/z arrays.
13811386
13821387
Returns
13831388
-------
@@ -1430,14 +1435,26 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
14301435
if kind in ("file", "grid"):
14311436
_data = (data,)
14321437
elif kind == "vectors":
1433-
_data = (x, y, z)
1438+
_data = [np.atleast_1d(x), np.atleast_1d(y)]
1439+
if z is not None:
1440+
_data.append(np.atleast_1d(z))
1441+
if extra_arrays:
1442+
_data.extend(extra_arrays)
14341443
elif kind == "matrix": # turn 2D arrays into list of vectors
14351444
try:
14361445
# pandas.DataFrame and xarray.Dataset types
14371446
_data = [array for _, array in data.items()]
14381447
except AttributeError:
1439-
# Python lists, tuples, and numpy ndarray types
1440-
_data = np.atleast_2d(np.asanyarray(data).T)
1448+
try:
1449+
# Just use virtualfile_from_matrix for 2D numpy.ndarray
1450+
# which are signed integer (i), unsigned integer (u) or
1451+
# floating point (f) types
1452+
assert data.ndim == 2 and data.dtype.kind in "iuf"
1453+
_virtualfile_from = self.virtualfile_from_matrix
1454+
_data = (data,)
1455+
except (AssertionError, AttributeError):
1456+
# Python lists, tuples, and numpy ndarray types
1457+
_data = np.atleast_2d(np.asanyarray(data).T)
14411458

14421459
# Finally create the virtualfile from the data, to be passed into GMT
14431460
file_context = _virtualfile_from(*_data)

pygmt/src/plot.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""
22
plot - Plot in two dimensions.
33
"""
4-
import numpy as np
54
from pygmt.clib import Session
65
from pygmt.exceptions import GMTInvalidInput
76
from pygmt.helpers import (
87
build_arg_string,
98
data_kind,
10-
dummy_context,
119
fmt_docstring,
1210
is_nonstr_iter,
1311
kwargs_to_strings,
@@ -226,14 +224,9 @@ def plot(self, x=None, y=None, data=None, sizes=None, direction=None, **kwargs):
226224

227225
with Session() as lib:
228226
# Choose how data will be passed in to the module
229-
if kind == "file":
230-
file_context = dummy_context(data)
231-
elif kind == "matrix":
232-
file_context = lib.virtualfile_from_matrix(data)
233-
elif kind == "vectors":
234-
file_context = lib.virtualfile_from_vectors(
235-
np.atleast_1d(x), np.atleast_1d(y), *extra_arrays
236-
)
227+
file_context = lib.virtualfile_from_data(
228+
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
229+
)
237230

238231
with file_context as fname:
239232
arg_str = " ".join([fname, build_arg_string(kwargs)])

pygmt/src/plot3d.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""
22
plot3d - Plot in three dimensions.
33
"""
4-
import numpy as np
54
from pygmt.clib import Session
65
from pygmt.exceptions import GMTInvalidInput
76
from pygmt.helpers import (
87
build_arg_string,
98
data_kind,
10-
dummy_context,
119
fmt_docstring,
1210
is_nonstr_iter,
1311
kwargs_to_strings,
@@ -189,14 +187,9 @@ def plot3d(
189187

190188
with Session() as lib:
191189
# Choose how data will be passed in to the module
192-
if kind == "file":
193-
file_context = dummy_context(data)
194-
elif kind == "matrix":
195-
file_context = lib.virtualfile_from_matrix(data)
196-
elif kind == "vectors":
197-
file_context = lib.virtualfile_from_vectors(
198-
np.atleast_1d(x), np.atleast_1d(y), np.atleast_1d(z), *extra_arrays
199-
)
190+
file_context = lib.virtualfile_from_data(
191+
check_kind="vector", data=data, x=x, y=y, z=z, extra_arrays=extra_arrays
192+
)
200193

201194
with file_context as fname:
202195
arg_str = " ".join([fname, build_arg_string(kwargs)])

pygmt/tests/test_info.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ def test_info():
2929
assert output == expected_output
3030

3131

32+
def test_info_2d_list():
33+
"""
34+
Make sure info works on a 2d list.
35+
"""
36+
output = info(table=[[0, 8], [3, 5], [6, 2]])
37+
expected_output = "<vector memory>: N = 3 <0/6> <2/8>\n"
38+
assert output == expected_output
39+
40+
3241
def test_info_dataframe():
3342
"""
3443
Make sure info works on pandas.DataFrame inputs.
@@ -105,7 +114,7 @@ def test_info_2d_array():
105114
table = np.loadtxt(POINTS_DATA)
106115
output = info(table=table)
107116
expected_output = (
108-
"<vector memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
117+
"<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
109118
)
110119
assert output == expected_output
111120

0 commit comments

Comments
 (0)