Skip to content

Commit 109f209

Browse files
seismanweiji14
andauthored
Better handling of optional virtual files (e.g., shading in Figure.grdimage) (#2493)
Co-authored-by: Wei Ji <[email protected]>
1 parent d580dff commit 109f209

File tree

6 files changed

+106
-79
lines changed

6 files changed

+106
-79
lines changed

pygmt/clib/session.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Uses ctypes to wrap most of the core functions from the C API.
66
"""
77
import ctypes as ctp
8+
import pathlib
89
import sys
910
from contextlib import contextmanager, nullcontext
1011

@@ -1474,6 +1475,7 @@ def virtualfile_from_data(
14741475
z=None,
14751476
extra_arrays=None,
14761477
required_z=False,
1478+
required_data=True,
14771479
):
14781480
"""
14791481
Store any data inside a virtual file.
@@ -1484,7 +1486,7 @@ def virtualfile_from_data(
14841486
14851487
Parameters
14861488
----------
1487-
check_kind : str
1489+
check_kind : str or None
14881490
Used to validate the type of data that can be passed in. Choose
14891491
from 'raster', 'vector', or None. Default is None (no validation).
14901492
data : str or pathlib.Path or xarray.DataArray or {table-like} or None
@@ -1498,6 +1500,9 @@ def virtualfile_from_data(
14981500
All of these arrays must be of the same size as the x/y/z arrays.
14991501
required_z : bool
15001502
State whether the 'z' column is required.
1503+
required_data : bool
1504+
Set to True when 'data' is required, or False when dealing with
1505+
optional virtual files. [Default is True].
15011506
15021507
Returns
15031508
-------
@@ -1528,21 +1533,25 @@ def virtualfile_from_data(
15281533
...
15291534
<vector memory>: N = 3 <7/9> <4/6> <1/3>
15301535
"""
1531-
kind = data_kind(data, x, y, z, required_z=required_z)
1532-
1533-
if check_kind == "raster" and kind not in ("file", "grid"):
1534-
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(data)}")
1535-
if check_kind == "vector" and kind not in (
1536-
"file",
1537-
"matrix",
1538-
"vectors",
1539-
"geojson",
1540-
):
1541-
raise GMTInvalidInput(f"Unrecognized data type for vector: {type(data)}")
1536+
kind = data_kind(
1537+
data, x, y, z, required_z=required_z, required_data=required_data
1538+
)
1539+
1540+
if check_kind:
1541+
valid_kinds = ("file", "arg") if required_data is False else ("file",)
1542+
if check_kind == "raster":
1543+
valid_kinds += ("grid",)
1544+
elif check_kind == "vector":
1545+
valid_kinds += ("matrix", "vectors", "geojson")
1546+
if kind not in valid_kinds:
1547+
raise GMTInvalidInput(
1548+
f"Unrecognized data type for {check_kind}: {type(data)}"
1549+
)
15421550

15431551
# Decide which virtualfile_from_ function to use
15441552
_virtualfile_from = {
15451553
"file": nullcontext,
1554+
"arg": nullcontext,
15461555
"geojson": tempfile_from_geojson,
15471556
"grid": self.virtualfile_from_grid,
15481557
# Note: virtualfile_from_matrix is not used because a matrix can be
@@ -1553,11 +1562,8 @@ def virtualfile_from_data(
15531562
}[kind]
15541563

15551564
# Ensure the data is an iterable (Python list or tuple)
1556-
if kind in ("geojson", "grid"):
1557-
_data = (data,)
1558-
elif kind == "file":
1559-
# Useful to handle `pathlib.Path` and string file path alike
1560-
_data = (str(data),)
1565+
if kind in ("geojson", "grid", "file", "arg"):
1566+
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
15611567
elif kind == "vectors":
15621568
_data = [np.atleast_1d(x), np.atleast_1d(y)]
15631569
if z is not None:

pygmt/helpers/utils.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def _validate_data_input(
18-
data=None, x=None, y=None, z=None, required_z=False, kind=None
18+
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
1919
):
2020
"""
2121
Check if the combination of data/x/y/z is valid.
@@ -25,6 +25,7 @@ def _validate_data_input(
2525
>>> _validate_data_input(data="infile")
2626
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
2727
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
28+
>>> _validate_data_input(data=None, required_data=False)
2829
>>> _validate_data_input()
2930
Traceback (most recent call last):
3031
...
@@ -41,6 +42,30 @@ def _validate_data_input(
4142
Traceback (most recent call last):
4243
...
4344
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
45+
>>> import numpy as np
46+
>>> import pandas as pd
47+
>>> import xarray as xr
48+
>>> data = np.arange(8).reshape((4, 2))
49+
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
50+
Traceback (most recent call last):
51+
...
52+
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
53+
>>> _validate_data_input(
54+
... data=pd.DataFrame(data, columns=["x", "y"]),
55+
... required_z=True,
56+
... kind="matrix",
57+
... )
58+
Traceback (most recent call last):
59+
...
60+
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
61+
>>> _validate_data_input(
62+
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
63+
... required_z=True,
64+
... kind="matrix",
65+
... )
66+
Traceback (most recent call last):
67+
...
68+
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
4469
>>> _validate_data_input(data="infile", x=[1, 2, 3])
4570
Traceback (most recent call last):
4671
...
@@ -61,11 +86,11 @@ def _validate_data_input(
6186
"""
6287
if data is None: # data is None
6388
if x is None and y is None: # both x and y are None
64-
raise GMTInvalidInput("No input data provided.")
65-
if x is None or y is None: # either x or y is None
89+
if required_data: # data is not optional
90+
raise GMTInvalidInput("No input data provided.")
91+
elif x is None or y is None: # either x or y is None
6692
raise GMTInvalidInput("Must provide both x and y.")
67-
# both x and y are not None, now check z
68-
if required_z and z is None:
93+
if required_z and z is None: # both x and y are not None, now check z
6994
raise GMTInvalidInput("Must provide x, y, and z.")
7095
else: # data is not None
7196
if x is not None or y is not None or z is not None:
@@ -81,38 +106,43 @@ def _validate_data_input(
81106
raise GMTInvalidInput("data must provide x, y, and z columns.")
82107

83108

84-
def data_kind(data=None, x=None, y=None, z=None, required_z=False):
109+
def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
85110
"""
86111
Check what kind of data is provided to a module.
87112
88113
Possible types:
89114
90115
* a file name provided as 'data'
91-
* a pathlib.Path provided as 'data'
92-
* an xarray.DataArray provided as 'data'
93-
* a matrix provided as 'data'
116+
* a pathlib.PurePath object provided as 'data'
117+
* an xarray.DataArray object provided as 'data'
118+
* a 2-D matrix provided as 'data'
94119
* 1-D arrays x and y (and z, optionally)
120+
* an optional argument (None, bool, int or float) provided as 'data'
95121
96122
Arguments should be ``None`` if not used. If doesn't fit any of these
97123
categories (or fits more than one), will raise an exception.
98124
99125
Parameters
100126
----------
101-
data : str or pathlib.Path or xarray.DataArray or {table-like} or None
127+
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
102128
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
103129
table, an :class:`xarray.DataArray`, a 1-D/2-D
104-
{table-classes}.
130+
{table-classes} or an option argument.
105131
x/y : 1-D arrays or None
106132
x and y columns as numpy arrays.
107133
z : 1-D array or None
108134
z column as numpy array. To be used optionally when x and y are given.
109135
required_z : bool
110136
State whether the 'z' column is required.
137+
required_data : bool
138+
Set to True when 'data' is required, or False when dealing with
139+
optional virtual files. [Default is True].
111140
112141
Returns
113142
-------
114143
kind : str
115-
One of: ``'file'``, ``'grid'``, ``'matrix'``, ``'vectors'``.
144+
One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``,
145+
or ``'vectors'``.
116146
117147
Examples
118148
--------
@@ -128,20 +158,39 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False):
128158
'file'
129159
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
130160
'file'
161+
>>> data_kind(data=None, x=None, y=None, required_data=False)
162+
'arg'
163+
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
164+
'arg'
165+
>>> data_kind(data=True, x=None, y=None, required_data=False)
166+
'arg'
131167
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
132168
'grid'
133169
"""
170+
# determine the data kind
134171
if isinstance(data, (str, pathlib.PurePath)):
135172
kind = "file"
173+
elif isinstance(data, (bool, int, float)) or (data is None and not required_data):
174+
kind = "arg"
136175
elif isinstance(data, xr.DataArray):
137176
kind = "grid"
138177
elif hasattr(data, "__geo_interface__"):
178+
# geo-like Python object that implements ``__geo_interface__``
179+
# (geopandas.GeoDataFrame or shapely.geometry)
139180
kind = "geojson"
140181
elif data is not None:
141182
kind = "matrix"
142183
else:
143184
kind = "vectors"
144-
_validate_data_input(data=data, x=x, y=y, z=z, required_z=required_z, kind=kind)
185+
_validate_data_input(
186+
data=data,
187+
x=x,
188+
y=y,
189+
z=z,
190+
required_z=required_z,
191+
required_data=required_data,
192+
kind=kind,
193+
)
145194
return kind
146195

147196

pygmt/src/dimfilter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def dimfilter(grid, **kwargs):
154154

155155
with GMTTempFile(suffix=".nc") as tmpfile:
156156
with Session() as lib:
157-
file_context = lib.virtualfile_from_data(check_kind=None, data=grid)
157+
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
158158
with file_context as infile:
159159
if (outgrid := kwargs.get("G")) is None:
160160
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile

pygmt/src/grdimage.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
"""
22
grdimage - Plot grids or images.
33
"""
4-
import contextlib
5-
64
from pygmt.clib import Session
7-
from pygmt.helpers import (
8-
build_arg_string,
9-
data_kind,
10-
fmt_docstring,
11-
kwargs_to_strings,
12-
use_alias,
13-
)
5+
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias
146

157
__doctest_skip__ = ["grdimage"]
168

@@ -180,16 +172,12 @@ def grdimage(self, grid, **kwargs):
180172
"""
181173
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
182174
with Session() as lib:
183-
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
184-
with contextlib.ExitStack() as stack:
185-
# shading using an xr.DataArray
186-
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
187-
shading_context = lib.virtualfile_from_data(
188-
check_kind="raster", data=kwargs["I"]
189-
)
190-
kwargs["I"] = stack.enter_context(shading_context)
191-
192-
fname = stack.enter_context(file_context)
175+
with lib.virtualfile_from_data(
176+
check_kind="raster", data=grid
177+
) as fname, lib.virtualfile_from_data(
178+
check_kind="raster", data=kwargs.get("I"), required_data=False
179+
) as shadegrid:
180+
kwargs["I"] = shadegrid
193181
lib.call_module(
194182
module="grdimage", args=build_arg_string(kwargs, infile=fname)
195183
)

pygmt/src/grdview.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
"""
22
grdview - Create a three-dimensional plot from a grid.
33
"""
4-
import contextlib
5-
64
from pygmt.clib import Session
7-
from pygmt.exceptions import GMTInvalidInput
8-
from pygmt.helpers import (
9-
build_arg_string,
10-
data_kind,
11-
fmt_docstring,
12-
kwargs_to_strings,
13-
use_alias,
14-
)
5+
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias
156

167
__doctest_skip__ = ["grdview"]
178

@@ -155,23 +146,12 @@ def grdview(self, grid, **kwargs):
155146
"""
156147
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
157148
with Session() as lib:
158-
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
159-
160-
with contextlib.ExitStack() as stack:
161-
if kwargs.get("G") is not None:
162-
# deal with kwargs["G"] if drapegrid is xr.DataArray
163-
drapegrid = kwargs["G"]
164-
if data_kind(drapegrid) in ("file", "grid"):
165-
if data_kind(drapegrid) == "grid":
166-
drape_context = lib.virtualfile_from_data(
167-
check_kind="raster", data=drapegrid
168-
)
169-
kwargs["G"] = stack.enter_context(drape_context)
170-
else:
171-
raise GMTInvalidInput(
172-
f"Unrecognized data type for drapegrid: {type(drapegrid)}"
173-
)
174-
fname = stack.enter_context(file_context)
149+
with lib.virtualfile_from_data(
150+
check_kind="raster", data=grid
151+
) as fname, lib.virtualfile_from_data(
152+
check_kind="raster", data=kwargs.get("G"), required_data=False
153+
) as drapegrid:
154+
kwargs["G"] = drapegrid
175155
lib.call_module(
176156
module="grdview", args=build_arg_string(kwargs, infile=fname)
177157
)

pygmt/tests/test_clib.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,9 @@ def test_virtualfile_from_data_required_z_matrix(array_func, kind):
439439
)
440440
data = array_func(dataframe)
441441
with clib.Session() as lib:
442-
with lib.virtualfile_from_data(data=data, required_z=True) as vfile:
442+
with lib.virtualfile_from_data(
443+
data=data, required_z=True, check_kind="vector"
444+
) as vfile:
443445
with GMTTempFile() as outfile:
444446
lib.call_module("info", f"{vfile} ->{outfile.name}")
445447
output = outfile.read(keep_tabs=True)
@@ -461,7 +463,9 @@ def test_virtualfile_from_data_required_z_matrix_missing():
461463
data = np.ones((5, 2))
462464
with clib.Session() as lib:
463465
with pytest.raises(GMTInvalidInput):
464-
with lib.virtualfile_from_data(data=data, required_z=True):
466+
with lib.virtualfile_from_data(
467+
data=data, required_z=True, check_kind="vector"
468+
):
465469
pass
466470

467471

0 commit comments

Comments
 (0)