diff --git a/doc/api/index.rst b/doc/api/index.rst
index 25de6d44adf..f1d480655b7 100644
--- a/doc/api/index.rst
+++ b/doc/api/index.rst
@@ -197,6 +197,14 @@ Getting metadata from tabular or grid data:
     info
     grdinfo
 
+Xarray Integration
+------------------
+
+.. autosummary::
+    :toctree: generated
+
+    GMTBackendEntrypoint
+
 Enums
 -----
 
diff --git a/pygmt/__init__.py b/pygmt/__init__.py
index f6d1040851f..d1f81094301 100644
--- a/pygmt/__init__.py
+++ b/pygmt/__init__.py
@@ -65,6 +65,7 @@
     x2sys_init,
     xyz2grd,
 )
+from pygmt.xarray import GMTBackendEntrypoint
 
 # Start our global modern mode session
 _begin()
diff --git a/pygmt/tests/test_xarray_backend.py b/pygmt/tests/test_xarray_backend.py
new file mode 100644
index 00000000000..b04128bce07
--- /dev/null
+++ b/pygmt/tests/test_xarray_backend.py
@@ -0,0 +1,74 @@
+"""
+Tests for xarray 'gmt' backend engine.
+"""
+
+import re
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+import xarray as xr
+from pygmt.enums import GridRegistration, GridType
+from pygmt.exceptions import GMTInvalidInput
+
+
+def test_xarray_backend_gmt_open_nc_grid():
+    """
+    Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF
+    grids.
+    """
+    with xr.open_dataarray(
+        "@static_earth_relief.nc", engine="gmt", raster_kind="grid"
+    ) as da:
+        assert da.sizes == {"lat": 14, "lon": 8}
+        assert da.dtype == "float32"
+        assert da.gmt.registration == GridRegistration.PIXEL
+        assert da.gmt.gtype == GridType.GEOGRAPHIC
+
+
+def test_xarray_backend_gmt_open_tif_image():
+    """
+    Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF
+    images.
+    """
+    with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da:
+        assert da.sizes == {"band": 3, "y": 180, "x": 360}
+        assert da.dtype == "uint8"
+        assert da.gmt.registration == GridRegistration.PIXEL
+        assert da.gmt.gtype == GridType.GEOGRAPHIC
+
+
+def test_xarray_backend_gmt_load_grd_grid():
+    """
+    Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD
+    grids.
+    """
+    da = xr.load_dataarray(
+        "@earth_relief_20m_holes.grd", engine="gmt", raster_kind="grid"
+    )
+    # Ensure data is in memory.
+    assert isinstance(da.data, np.ndarray)
+    npt.assert_allclose(da.min(), -4929.5)
+    assert da.sizes == {"lat": 31, "lon": 31}
+    assert da.dtype == "float32"
+    assert da.gmt.registration == GridRegistration.GRIDLINE
+    assert da.gmt.gtype == GridType.GEOGRAPHIC
+
+
+def test_xarray_backend_gmt_read_invalid_kind():
+    """
+    Check that xarray.open_dataarray(..., engine="gmt") fails with missing or incorrect
+    'raster_kind'.
+    """
+    with pytest.raises(
+        TypeError,
+        match=re.escape(
+            "GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'"
+        ),
+    ):
+        xr.open_dataarray("nokind.nc", engine="gmt")
+
+    with pytest.raises(GMTInvalidInput):
+        xr.open_dataarray(
+            filename_or_obj="invalid.tif", engine="gmt", raster_kind="invalid"
+        )
diff --git a/pygmt/xarray/__init__.py b/pygmt/xarray/__init__.py
new file mode 100644
index 00000000000..758b58b35dd
--- /dev/null
+++ b/pygmt/xarray/__init__.py
@@ -0,0 +1,5 @@
+"""
+PyGMT integration with Xarray accessors and backends.
+"""
+
+from pygmt.xarray.backend import GMTBackendEntrypoint
diff --git a/pygmt/xarray/backend.py b/pygmt/xarray/backend.py
new file mode 100644
index 00000000000..a95e98983db
--- /dev/null
+++ b/pygmt/xarray/backend.py
@@ -0,0 +1,119 @@
+"""
+An xarray backend for reading raster grid/image files using the 'gmt' engine.
+"""
+
+from typing import Literal
+
+import xarray as xr
+from pygmt._typing import PathLike
+from pygmt.clib import Session
+from pygmt.exceptions import GMTInvalidInput
+from pygmt.helpers import build_arg_list
+from pygmt.src.which import which
+from xarray.backends import BackendEntrypoint
+
+
+class GMTBackendEntrypoint(BackendEntrypoint):
+    """
+    Xarray backend to read raster grid/image files using 'gmt' engine.
+
+    Internally, GMT uses the netCDF C library to read netCDF files, and GDAL for GeoTIFF
+    and other raster formats. See :gmt-docs:`reference/features.html#grid-file-format`
+    for more details about supported formats. This GMT engine can also read
+    :gmt-docs:`GMT remote datasets <datasets/remote-data.html>` (file names starting
+    with an `@`) directly, and pre-loads :class:`pygmt.GMTDataArrayAccessor` properties
+    (in the '.gmt' accessor) for easy access to GMT-specific metadata and features.
+
+    When using :py:func:`xarray.open_dataarray` or :py:func:`xarray.load_dataarray` with
+    ``engine="gmt"``, the ``raster_kind`` parameter is required and can be either:
+
+    - ``"grid"``: for reading single-band raster grids
+    - ``"image"``: for reading multi-band raster images
+
+    Examples
+    --------
+    Read a single-band netCDF file using ``raster_kind="grid"``
+
+    >>> import pygmt
+    >>> import xarray as xr
+    >>>
+    >>> da_grid = xr.open_dataarray(
+    ...     "@static_earth_relief.nc", engine="gmt", raster_kind="grid"
+    ... )
+    >>> da_grid  # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+    <xarray.DataArray 'z' (lat: 14, lon: 8)>...
+    [112 values with dtype=float32]
+    Coordinates:
+      * lat      (lat) float64... -23.5 -22.5 -21.5 -20.5 ... -12.5 -11.5 -10.5
+      * lon      (lon) float64... -54.5 -53.5 -52.5 -51.5 -50.5 -49.5 -48.5 -47.5
+    Attributes:...
+        Conventions:   CF-1.7
+        title:         Produced by grdcut
+        history:       grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_eart...
+        description:   Reduced by Gaussian Cartesian filtering (111.2 km fullwidt...
+        actual_range:  [190. 981.]
+        long_name:     elevation (m)
+
+    Read a multi-band GeoTIFF file using ``raster_kind="image"``
+
+    >>> da_image = xr.open_dataarray(
+    ...     "@earth_night_01d", engine="gmt", raster_kind="image"
+    ... )
+    >>> da_image  # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+    <xarray.DataArray 'z' (band: 3, y: 180, x: 360)>...
+    [194400 values with dtype=uint8]
+    Coordinates:
+      * y        (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
+      * x        (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
+      * band     (band) uint8... 1 2 3
+    Attributes:...
+        long_name:  z
+    """
+
+    description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT."
+    open_dataset_parameters = ("filename_or_obj", "raster_kind")
+    url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html"
+
+    def open_dataset(  # type: ignore[override]
+        self,
+        filename_or_obj: PathLike,
+        *,
+        drop_variables=None,  # noqa: ARG002
+        raster_kind: Literal["grid", "image"],
+        # other backend specific keyword arguments
+        # `chunks` and `cache` DO NOT go here, they are handled by xarray
+    ) -> xr.Dataset:
+        """
+        Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
+
+        Parameters
+        ----------
+        filename_or_obj
+            File path to a netCDF (.nc), GeoTIFF (.tif) or other grid/image file format
+            that can be read by GMT via the netCDF or GDAL C libraries. See also
+            :gmt-docs:`reference/features.html#grid-file-format`.
+        raster_kind
+            Whether to read the file as a "grid" (single-band) or "image" (multi-band).
+        """
+        if raster_kind not in {"grid", "image"}:
+            msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'."
+            raise GMTInvalidInput(msg)
+
+        with Session() as lib:
+            with lib.virtualfile_out(kind=raster_kind) as voutfile:
+                kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]}
+                lib.call_module(
+                    module="read",
+                    args=[filename_or_obj, voutfile, *build_arg_list(kwdict)],
+                )
+
+                raster: xr.DataArray = lib.virtualfile_to_raster(
+                    vfname=voutfile, kind=raster_kind
+                )
+                # Add "source" encoding
+                source = which(fname=filename_or_obj)
+                raster.encoding["source"] = (
+                    source[0] if isinstance(source, list) else source
+                )
+                _ = raster.gmt  # Load GMTDataArray accessor information
+                return raster.to_dataset()
diff --git a/pyproject.toml b/pyproject.toml
index 2ebe59e7f24..dbacb85b392 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,9 @@ all = [
     "rioxarray",
 ]
 
+[project.entry-points."xarray.backends"]
+gmt = "pygmt.xarray:GMTBackendEntrypoint"
+
 [project.urls]
 "Homepage" = "https://www.pygmt.org"
 "Documentation" = "https://www.pygmt.org"