diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index d7f1b585f8d..069a762001d 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -5,6 +5,7 @@ Uses ctypes to wrap most of the core functions from the C API. """ import ctypes as ctp +import pathlib import sys from contextlib import contextmanager, nullcontext @@ -1474,6 +1475,7 @@ def virtualfile_from_data( z=None, extra_arrays=None, required_z=False, + required_data=True, ): """ Store any data inside a virtual file. @@ -1484,7 +1486,7 @@ def virtualfile_from_data( Parameters ---------- - check_kind : str + check_kind : str or None Used to validate the type of data that can be passed in. Choose from 'raster', 'vector', or None. Default is None (no validation). data : str or pathlib.Path or xarray.DataArray or {table-like} or None @@ -1498,6 +1500,9 @@ def virtualfile_from_data( All of these arrays must be of the same size as the x/y/z arrays. required_z : bool State whether the 'z' column is required. + required_data : bool + Set to True when 'data' is required, or False when dealing with + optional virtual files. [Default is True]. Returns ------- @@ -1528,21 +1533,25 @@ def virtualfile_from_data( ... 
: N = 3 <7/9> <4/6> <1/3> """ - kind = data_kind(data, x, y, z, required_z=required_z) - - if check_kind == "raster" and kind not in ("file", "grid"): - raise GMTInvalidInput(f"Unrecognized data type for grid: {type(data)}") - if check_kind == "vector" and kind not in ( - "file", - "matrix", - "vectors", - "geojson", - ): - raise GMTInvalidInput(f"Unrecognized data type for vector: {type(data)}") + kind = data_kind( + data, x, y, z, required_z=required_z, required_data=required_data + ) + + if check_kind: + valid_kinds = ("file", "arg") if required_data is False else ("file",) + if check_kind == "raster": + valid_kinds += ("grid",) + elif check_kind == "vector": + valid_kinds += ("matrix", "vectors", "geojson") + if kind not in valid_kinds: + raise GMTInvalidInput( + f"Unrecognized data type for {check_kind}: {type(data)}" + ) # Decide which virtualfile_from_ function to use _virtualfile_from = { "file": nullcontext, + "arg": nullcontext, "geojson": tempfile_from_geojson, "grid": self.virtualfile_from_grid, # Note: virtualfile_from_matrix is not used because a matrix can be @@ -1553,11 +1562,8 @@ def virtualfile_from_data( }[kind] # Ensure the data is an iterable (Python list or tuple) - if kind in ("geojson", "grid"): - _data = (data,) - elif kind == "file": - # Useful to handle `pathlib.Path` and string file path alike - _data = (str(data),) + if kind in ("geojson", "grid", "file", "arg"): + _data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),) elif kind == "vectors": _data = [np.atleast_1d(x), np.atleast_1d(y)] if z is not None: diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 286f0430c9c..e52197012be 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -15,7 +15,7 @@ def _validate_data_input( - data=None, x=None, y=None, z=None, required_z=False, kind=None + data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None ): """ Check if the combination of data/x/y/z is valid. 
@@ -25,6 +25,7 @@ def _validate_data_input( >>> _validate_data_input(data="infile") >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6]) >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9]) + >>> _validate_data_input(data=None, required_data=False) >>> _validate_data_input() Traceback (most recent call last): ... @@ -41,6 +42,30 @@ def _validate_data_input( Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. + >>> import numpy as np + >>> import pandas as pd + >>> import xarray as xr + >>> data = np.arange(8).reshape((4, 2)) + >>> _validate_data_input(data=data, required_z=True, kind="matrix") + Traceback (most recent call last): + ... + pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + >>> _validate_data_input( + ... data=pd.DataFrame(data, columns=["x", "y"]), + ... required_z=True, + ... kind="matrix", + ... ) + Traceback (most recent call last): + ... + pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + >>> _validate_data_input( + ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), + ... required_z=True, + ... kind="matrix", + ... ) + Traceback (most recent call last): + ... + pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. >>> _validate_data_input(data="infile", x=[1, 2, 3]) Traceback (most recent call last): ... 
@@ -61,11 +86,11 @@ def _validate_data_input( """ if data is None: # data is None if x is None and y is None: # both x and y are None - raise GMTInvalidInput("No input data provided.") - if x is None or y is None: # either x or y is None + if required_data: # data is not optional + raise GMTInvalidInput("No input data provided.") + elif x is None or y is None: # either x or y is None raise GMTInvalidInput("Must provide both x and y.") - # both x and y are not None, now check z - if required_z and z is None: + if required_z and z is None: # both x and y are not None, now check z raise GMTInvalidInput("Must provide x, y, and z.") else: # data is not None if x is not None or y is not None or z is not None: @@ -81,38 +106,43 @@ def _validate_data_input( raise GMTInvalidInput("data must provide x, y, and z columns.") -def data_kind(data=None, x=None, y=None, z=None, required_z=False): +def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True): """ Check what kind of data is provided to a module. Possible types: * a file name provided as 'data' - * a pathlib.Path provided as 'data' - * an xarray.DataArray provided as 'data' - * a matrix provided as 'data' + * a pathlib.PurePath object provided as 'data' + * an xarray.DataArray object provided as 'data' + * a 2-D matrix provided as 'data' * 1-D arrays x and y (and z, optionally) + * an optional argument (None, bool, int or float) provided as 'data' Arguments should be ``None`` if not used. If doesn't fit any of these categories (or fits more than one), will raise an exception. Parameters ---------- - data : str or pathlib.Path or xarray.DataArray or {table-like} or None + data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like} Pass in either a file name or :class:`pathlib.Path` to an ASCII data table, an :class:`xarray.DataArray`, a 1-D/2-D - {table-classes}. + {table-classes} or an optional argument. x/y : 1-D arrays or None x and y columns as numpy arrays. 
z : 1-D array or None z column as numpy array. To be used optionally when x and y are given. required_z : bool State whether the 'z' column is required. + required_data : bool + Set to True when 'data' is required, or False when dealing with + optional virtual files. [Default is True]. Returns ------- kind : str - One of: ``'file'``, ``'grid'``, ``'matrix'``, ``'vectors'``. + One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``, + or ``'vectors'``. Examples -------- @@ -128,20 +158,39 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False): 'file' >>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None) 'file' + >>> data_kind(data=None, x=None, y=None, required_data=False) + 'arg' + >>> data_kind(data=2.0, x=None, y=None, required_data=False) + 'arg' + >>> data_kind(data=True, x=None, y=None, required_data=False) + 'arg' >>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) 'grid' """ + # determine the data kind if isinstance(data, (str, pathlib.PurePath)): kind = "file" + elif isinstance(data, (bool, int, float)) or (data is None and not required_data): + kind = "arg" elif isinstance(data, xr.DataArray): kind = "grid" elif hasattr(data, "__geo_interface__"): + # geo-like Python object that implements ``__geo_interface__`` + # (geopandas.GeoDataFrame or shapely.geometry) kind = "geojson" elif data is not None: kind = "matrix" else: kind = "vectors" - _validate_data_input(data=data, x=x, y=y, z=z, required_z=required_z, kind=kind) + _validate_data_input( + data=data, + x=x, + y=y, + z=z, + required_z=required_z, + required_data=required_data, + kind=kind, + ) return kind diff --git a/pygmt/src/dimfilter.py b/pygmt/src/dimfilter.py index ba4e39ce1bd..ddfd1d31e0c 100644 --- a/pygmt/src/dimfilter.py +++ b/pygmt/src/dimfilter.py @@ -154,7 +154,7 @@ def dimfilter(grid, **kwargs): with GMTTempFile(suffix=".nc") as tmpfile: with Session() as lib: - file_context = lib.virtualfile_from_data(check_kind=None, data=grid) + 
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as infile: if (outgrid := kwargs.get("G")) is None: kwargs["G"] = outgrid = tmpfile.name # output to tmpfile diff --git a/pygmt/src/grdimage.py b/pygmt/src/grdimage.py index 93e1fffb2c4..d036453985e 100644 --- a/pygmt/src/grdimage.py +++ b/pygmt/src/grdimage.py @@ -1,16 +1,8 @@ """ grdimage - Plot grids or images. """ -import contextlib - from pygmt.clib import Session -from pygmt.helpers import ( - build_arg_string, - data_kind, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias __doctest_skip__ = ["grdimage"] @@ -180,16 +172,12 @@ def grdimage(self, grid, **kwargs): """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access with Session() as lib: - file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) - with contextlib.ExitStack() as stack: - # shading using an xr.DataArray - if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid": - shading_context = lib.virtualfile_from_data( - check_kind="raster", data=kwargs["I"] - ) - kwargs["I"] = stack.enter_context(shading_context) - - fname = stack.enter_context(file_context) + with lib.virtualfile_from_data( + check_kind="raster", data=grid + ) as fname, lib.virtualfile_from_data( + check_kind="raster", data=kwargs.get("I"), required_data=False + ) as shadegrid: + kwargs["I"] = shadegrid lib.call_module( module="grdimage", args=build_arg_string(kwargs, infile=fname) ) diff --git a/pygmt/src/grdview.py b/pygmt/src/grdview.py index bd8be224eb8..57c8840ccca 100644 --- a/pygmt/src/grdview.py +++ b/pygmt/src/grdview.py @@ -1,17 +1,8 @@ """ grdview - Create a three-dimensional plot from a grid. 
""" -import contextlib - from pygmt.clib import Session -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import ( - build_arg_string, - data_kind, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias __doctest_skip__ = ["grdview"] @@ -155,23 +146,12 @@ def grdview(self, grid, **kwargs): """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access with Session() as lib: - file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) - - with contextlib.ExitStack() as stack: - if kwargs.get("G") is not None: - # deal with kwargs["G"] if drapegrid is xr.DataArray - drapegrid = kwargs["G"] - if data_kind(drapegrid) in ("file", "grid"): - if data_kind(drapegrid) == "grid": - drape_context = lib.virtualfile_from_data( - check_kind="raster", data=drapegrid - ) - kwargs["G"] = stack.enter_context(drape_context) - else: - raise GMTInvalidInput( - f"Unrecognized data type for drapegrid: {type(drapegrid)}" - ) - fname = stack.enter_context(file_context) + with lib.virtualfile_from_data( + check_kind="raster", data=grid + ) as fname, lib.virtualfile_from_data( + check_kind="raster", data=kwargs.get("G"), required_data=False + ) as drapegrid: + kwargs["G"] = drapegrid lib.call_module( module="grdview", args=build_arg_string(kwargs, infile=fname) ) diff --git a/pygmt/tests/test_clib.py b/pygmt/tests/test_clib.py index 6ad698c13ae..38a77f5f064 100644 --- a/pygmt/tests/test_clib.py +++ b/pygmt/tests/test_clib.py @@ -439,7 +439,9 @@ def test_virtualfile_from_data_required_z_matrix(array_func, kind): ) data = array_func(dataframe) with clib.Session() as lib: - with lib.virtualfile_from_data(data=data, required_z=True) as vfile: + with lib.virtualfile_from_data( + data=data, required_z=True, check_kind="vector" + ) as vfile: with GMTTempFile() as outfile: lib.call_module("info", f"{vfile} ->{outfile.name}") output = outfile.read(keep_tabs=True) @@ -461,7 
+463,9 @@ def test_virtualfile_from_data_required_z_matrix_missing(): data = np.ones((5, 2)) with clib.Session() as lib: with pytest.raises(GMTInvalidInput): - with lib.virtualfile_from_data(data=data, required_z=True): + with lib.virtualfile_from_data( + data=data, required_z=True, check_kind="vector" + ): pass