Skip to content

Commit c83c9f8

Browse files
authored
Merge branch 'refactor/data_kind' into data_kind/vectors-none
2 parents a1e67d3 + 991f688 commit c83c9f8

File tree

6 files changed

+124
-102
lines changed

6 files changed

+124
-102
lines changed

pygmt/_typing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""
2+
Type aliases for type hints.
3+
"""
4+
5+
from typing import Literal
6+
7+
# Anchor codes
8+
AnchorCode = Literal["TL", "TC", "TR", "ML", "MC", "MR", "BL", "BC", "BR"]

pygmt/clib/conversion.py

Lines changed: 67 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,45 @@
55
import ctypes as ctp
66
import warnings
77
from collections.abc import Sequence
8+
from typing import Any
89

910
import numpy as np
11+
import pandas as pd
12+
import xarray as xr
13+
from packaging.version import Version
1014
from pygmt.exceptions import GMTInvalidInput
1115

1216

13-
def dataarray_to_matrix(grid):
17+
def dataarray_to_matrix(
18+
grid: xr.DataArray,
19+
) -> tuple[np.ndarray, list[float], list[float]]:
1420
"""
15-
Transform an xarray.DataArray into a data 2-D array and metadata.
21+
Transform an xarray.DataArray into a 2-D numpy array and metadata.
1622
17-
Use this to extract the underlying numpy array of data and the region and
18-
increment for the grid.
23+
Use this to extract the underlying numpy array of data and the region and increment
24+
for the grid.
1925
20-
Only allows grids with two dimensions and constant grid spacing (GMT
21-
doesn't allow variable grid spacing). If the latitude and/or longitude
22-
increments of the input grid are negative, the output matrix will be
23-
sorted by the DataArray coordinates to yield positive increments.
26+
Only allows grids with two dimensions and constant grid spacings (GMT doesn't allow
27+
variable grid spacings). If the latitude and/or longitude increments of the input
28+
grid are negative, the output matrix will be sorted by the DataArray coordinates to
29+
yield positive increments.
2430
25-
If the underlying data array is not C contiguous, for example if it's a
26-
slice of a larger grid, a copy will need to be generated.
31+
If the underlying data array is not C contiguous, for example, if it's a slice of a
32+
larger grid, a copy will need to be generated.
2733
2834
Parameters
2935
----------
30-
grid : xarray.DataArray
31-
The input grid as a DataArray instance. Information is retrieved from
32-
the coordinate arrays, not from headers.
36+
grid
37+
The input grid as a DataArray instance. Information is retrieved from the
38+
coordinate arrays, not from headers.
3339
3440
Returns
3541
-------
36-
matrix : 2-D array
42+
matrix
3743
The 2-D array of data from the grid.
38-
region : list
44+
region
3945
The West, East, South, North boundaries of the grid.
40-
inc : list
46+
inc
4147
The grid spacing in East-West and North-South, respectively.
4248
4349
Raises
@@ -62,8 +68,8 @@ def dataarray_to_matrix(grid):
6268
(180, 360)
6369
>>> matrix.flags.c_contiguous
6470
True
65-
>>> # Using a slice of the grid, the matrix will be copied to guarantee
66-
>>> # that it's C-contiguous in memory. The increment should be unchanged.
71+
>>> # Using a slice of the grid, the matrix will be copied to guarantee that it's
72+
>>> # C-contiguous in memory. The increment should be unchanged.
6773
>>> matrix, region, inc = dataarray_to_matrix(grid[10:41, 30:101])
6874
>>> matrix.flags.c_contiguous
6975
True
@@ -73,7 +79,7 @@ def dataarray_to_matrix(grid):
7379
[-150.0, -79.0, -80.0, -49.0]
7480
>>> print(inc)
7581
[1.0, 1.0]
76-
>>> # but not if only taking every other grid point.
82+
>>> # The increment should change accordingly if taking every other grid point.
7783
>>> matrix, region, inc = dataarray_to_matrix(grid[10:41:2, 30:101:2])
7884
>>> matrix.flags.c_contiguous
7985
True
@@ -85,21 +91,19 @@ def dataarray_to_matrix(grid):
8591
[2.0, 2.0]
8692
"""
8793
if len(grid.dims) != 2:
88-
raise GMTInvalidInput(
89-
f"Invalid number of grid dimensions '{len(grid.dims)}'. Must be 2."
90-
)
94+
msg = f"Invalid number of grid dimensions 'len({grid.dims})'. Must be 2."
95+
raise GMTInvalidInput(msg)
96+
9197
# Extract region and inc from the grid
92-
region = []
93-
inc = []
94-
# Reverse the dims because it is rows, columns ordered. In geographic
95-
# grids, this would be North-South, East-West. GMT's region and inc are
96-
# East-West, North-South.
98+
region, inc = [], []
99+
# Reverse the dims because it is rows, columns ordered. In geographic grids, this
100+
# would be North-South, East-West. GMT's region and inc are East-West, North-South.
97101
for dim in grid.dims[::-1]:
98102
coord = grid.coords[dim].to_numpy()
99-
coord_incs = coord[1:] - coord[0:-1]
103+
coord_incs = coord[1:] - coord[:-1]
100104
coord_inc = coord_incs[0]
101105
if not np.allclose(coord_incs, coord_inc):
102-
# calculate the increment if irregular spacing is found
106+
# Calculate the increment if irregular spacing is found.
103107
coord_inc = (coord[-1] - coord[0]) / (coord.size - 1)
104108
msg = (
105109
f"Grid may have irregular spacing in the '{dim}' dimension, "
@@ -108,9 +112,8 @@ def dataarray_to_matrix(grid):
108112
)
109113
warnings.warn(msg, category=RuntimeWarning, stacklevel=2)
110114
if coord_inc == 0:
111-
raise GMTInvalidInput(
112-
f"Grid has a zero increment in the '{dim}' dimension."
113-
)
115+
msg = f"Grid has a zero increment in the '{dim}' dimension."
116+
raise GMTInvalidInput(msg)
114117
region.extend(
115118
[
116119
coord.min() - coord_inc / 2 * grid.gmt.registration,
@@ -129,26 +132,25 @@ def dataarray_to_matrix(grid):
129132
return matrix, region, inc
130133

131134

132-
def vectors_to_arrays(vectors):
135+
def vectors_to_arrays(vectors: Sequence[Any]) -> list[np.ndarray]:
133136
"""
134-
Convert 1-D vectors (lists, arrays, or pandas.Series) to C contiguous 1-D arrays.
137+
Convert 1-D vectors (scalars, lists, or array-like) to C contiguous 1-D arrays.
135138
136-
Arrays must be in C contiguous order for us to pass their memory pointers
137-
to GMT. If any are not, convert them to C order (which requires copying the
138-
memory). This usually happens when vectors are columns of a 2-D array or
139-
have been sliced.
139+
Arrays must be in C contiguous order for us to pass their memory pointers to GMT.
140+
If any are not, convert them to C order (which requires copying the memory). This
141+
usually happens when vectors are columns of a 2-D array or have been sliced.
140142
141-
If a vector is a list or pandas.Series, get the underlying numpy array.
143+
The returned arrays are guaranteed to be C contiguous and at least 1-D.
142144
143145
Parameters
144146
----------
145-
vectors : list of lists, 1-D arrays, or pandas.Series
147+
vectors
146148
The vectors that must be converted.
147149
148150
Returns
149151
-------
150-
arrays : list of 1-D arrays
151-
The converted numpy arrays
152+
arrays
153+
List of converted numpy arrays.
152154
153155
Examples
154156
--------
@@ -178,6 +180,10 @@ def vectors_to_arrays(vectors):
178180
>>> [i.ndim for i in data] # Check that they are 1-D arrays
179181
[1, 1, 1]
180182
183+
>>> series = pd.Series(data=[0, 4, pd.NA, 8, 6], dtype=pd.Int32Dtype())
184+
>>> vectors_to_arrays([series])
185+
[array([ 0., 4., nan, 8., 6.])]
186+
181187
>>> import datetime
182188
>>> import pytest
183189
>>> pa = pytest.importorskip("pyarrow")
@@ -205,8 +211,20 @@ def vectors_to_arrays(vectors):
205211
}
206212
arrays = []
207213
for vector in vectors:
208-
vec_dtype = str(getattr(vector, "dtype", ""))
209-
arrays.append(np.ascontiguousarray(vector, dtype=dtypes.get(vec_dtype)))
214+
if (
215+
hasattr(vector, "isna")
216+
and vector.isna().any()
217+
and Version(pd.__version__) < Version("2.2")
218+
):
219+
# Workaround for dealing with pd.NA with pandas < 2.2.
220+
# Bug report at: https://github.com/GenericMappingTools/pygmt/issues/2844
221+
# Following SPEC0, pandas 2.1 will be dropped in 2025 Q3, so it's likely
222+
# we can remove the workaround in PyGMT v0.17.0.
223+
array = np.ascontiguousarray(vector.astype(float))
224+
else:
225+
vec_dtype = str(getattr(vector, "dtype", ""))
226+
array = np.ascontiguousarray(vector, dtype=dtypes.get(vec_dtype))
227+
arrays.append(array)
210228
return arrays
211229

212230

@@ -289,16 +307,15 @@ def strings_to_ctypes_array(strings: Sequence[str]) -> ctp.Array:
289307
return (ctp.c_char_p * len(strings))(*[s.encode() for s in strings])
290308

291309

292-
def array_to_datetime(array):
310+
def array_to_datetime(array: Sequence[Any]) -> np.ndarray:
293311
"""
294312
Convert a 1-D datetime array from various types into numpy.datetime64.
295313
296-
If the input array is not in legal datetime formats, raise a ValueError
297-
exception.
314+
If the input array is not in legal datetime formats, raise a ValueError exception.
298315
299316
Parameters
300317
----------
301-
array : list or 1-D array
318+
array
302319
The input datetime array in various formats.
303320
304321
Supported types:
@@ -310,7 +327,8 @@ def array_to_datetime(array):
310327
311328
Returns
312329
-------
313-
array : 1-D datetime array in numpy.datetime64
330+
array
331+
1-D datetime array in numpy.datetime64.
314332
315333
Raises
316334
------

pygmt/clib/session.py

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -105,33 +105,28 @@ class Session:
105105
"""
106106
A GMT API session where most operations involving the C API happen.
107107
108-
Works as a context manager (for use in a ``with`` block) to create a GMT C
109-
API session and destroy it in the end to clean up memory.
108+
Works as a context manager (for use in a ``with`` block) to create a GMT C API
109+
session and destroy it in the end to clean up memory.
110110
111-
Functions of the shared library are exposed as methods of this class. Most
112-
methods MUST be used with an open session (inside a ``with`` block). If
113-
creating GMT data structures to communicate data, put that code inside the
114-
same ``with`` block as the API calls that will use the data.
111+
Functions of the shared library are exposed as methods of this class. Most methods
112+
MUST be used with an open session (inside a ``with`` block). If creating GMT data
113+
structures to communicate data, put that code inside the same ``with`` block as the
114+
API calls that will use the data.
115115
116-
By default, will let :mod:`ctypes` try to find the GMT shared library
117-
(``libgmt``). If the environment variable :term:`GMT_LIBRARY_PATH` is set, will
118-
look for the shared library in the directory specified by it.
116+
By default, will let :mod:`ctypes` try to find the GMT shared library (``libgmt``).
117+
If the environment variable :term:`GMT_LIBRARY_PATH` is set, will look for the
118+
shared library in the directory specified by it.
119119
120-
A ``GMTVersionError`` exception will be raised if the GMT shared library
121-
reports a version older than the required minimum GMT version.
122-
123-
The ``session_pointer`` attribute holds a ctypes pointer to the currently
124-
open session.
120+
The ``session_pointer`` attribute holds a ctypes pointer to the currently open
121+
session.
125122
126123
Raises
127124
------
128125
GMTCLibNotFoundError
129-
If there was any problem loading the library (couldn't find it or
130-
couldn't access the functions).
126+
If there was any problem loading the library (couldn't find it or couldn't
127+
access the functions).
131128
GMTCLibNoSessionError
132-
If you try to call a method outside of a 'with' block.
133-
GMTVersionError
134-
If the minimum required version of GMT is not found.
129+
If you try to call a method outside of a ``with`` block.
135130
136131
Examples
137132
--------
@@ -141,45 +136,44 @@ class Session:
141136
>>> grid = load_static_earth_relief()
142137
>>> type(grid)
143138
<class 'xarray.core.dataarray.DataArray'>
144-
>>> # Create a session and destroy it automatically when exiting the "with"
145-
>>> # block.
146-
>>> with Session() as ses:
139+
>>> # Create a session and destroy it automatically when exiting the "with" block.
140+
>>> with Session() as lib:
147141
... # Create a virtual file and link to the memory block of the grid.
148-
... with ses.virtualfile_from_grid(grid) as fin:
142+
... with lib.virtualfile_from_grid(grid) as fin:
149143
... # Create a temp file to use as output.
150144
... with GMTTempFile() as fout:
151-
... # Call the grdinfo module with the virtual file as input
152-
... # and the temp file as output.
153-
... ses.call_module("grdinfo", [fin, "-C", f"->{fout.name}"])
145+
... # Call the grdinfo module with the virtual file as input and the
146+
... # temp file as output.
147+
... lib.call_module("grdinfo", [fin, "-C", f"->{fout.name}"])
154148
... # Read the contents of the temp file before it's deleted.
155149
... print(fout.read().strip())
156150
-55 -47 -24 -10 190 981 1 1 8 14 1 1
157151
"""
158152

159153
@property
160-
def session_pointer(self):
154+
def session_pointer(self) -> ctp.c_void_p:
161155
"""
162156
The :class:`ctypes.c_void_p` pointer to the current open GMT session.
163157
164158
Raises
165159
------
166160
GMTCLibNoSessionError
167-
If trying to access without a currently open GMT session (i.e.,
168-
outside of the context manager).
161+
If trying to access without a currently open GMT session (i.e., outside of
162+
the context manager).
169163
"""
170164
if not hasattr(self, "_session_pointer") or self._session_pointer is None:
171165
raise GMTCLibNoSessionError("No currently open GMT API session.")
172166
return self._session_pointer
173167

174168
@session_pointer.setter
175-
def session_pointer(self, session):
169+
def session_pointer(self, session: ctp.c_void_p):
176170
"""
177171
Set the session void pointer.
178172
"""
179173
self._session_pointer = session
180174

181175
@property
182-
def info(self):
176+
def info(self) -> dict[str, str]:
183177
"""
184178
Dictionary with the GMT version and default paths and parameters.
185179
"""
@@ -629,31 +623,29 @@ def call_module(self, module: str, args: str | list[str]):
629623

630624
# 'args' can be (1) a single string or (2) a list of strings.
631625
argv: bytes | ctp.Array[ctp.c_char_p] | None
632-
if isinstance(args, str):
633-
# 'args' is a single string that contains whitespace-separated arguments.
634-
# In this way, we need to correctly handle option arguments that contain
635-
# whitespaces or quotation marks. It's used in PyGMT <= v0.11.0 but is no
636-
# longer recommended.
637-
mode = self["GMT_MODULE_CMD"]
638-
argv = args.encode()
639-
elif isinstance(args, list):
626+
if isinstance(args, list):
640627
# 'args' is a list of strings and each string contains a module argument.
641628
# In this way, GMT can correctly handle option arguments with whitespaces or
642629
# quotation marks. This is the preferred way to pass arguments to the GMT
643630
# API and is used for PyGMT >= v0.12.0.
644631
mode = len(args) # 'mode' is the number of arguments.
645632
# Pass a null pointer if no arguments are specified.
646633
argv = strings_to_ctypes_array(args) if mode != 0 else None
634+
elif isinstance(args, str):
635+
# 'args' is a single string that contains whitespace-separated arguments.
636+
# In this way, we need to correctly handle option arguments that contain
637+
# whitespaces or quotation marks. It's used in PyGMT <= v0.11.0 but is no
638+
# longer recommended.
639+
mode = self["GMT_MODULE_CMD"]
640+
argv = args.encode()
647641
else:
648-
raise GMTInvalidInput(
649-
"'args' must be either a string or a list of strings."
650-
)
642+
msg = "'args' must either be a list of strings (recommended) or a string."
643+
raise GMTInvalidInput(msg)
651644

652645
status = c_call_module(self.session_pointer, module.encode(), mode, argv)
653646
if status != 0:
654-
raise GMTCLibError(
655-
f"Module '{module}' failed with status code {status}:\n{self._error_message}"
656-
)
647+
msg = f"Module '{module}' failed with status code {status}:\n{self._error_message}"
648+
raise GMTCLibError(msg)
657649

658650
def create_data(
659651
self,

pygmt/helpers/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sys
1212
import time
1313
import webbrowser
14-
from collections.abc import Iterable, Sequence
14+
from collections.abc import Iterable, Mapping, Sequence
1515
from typing import Any, Literal
1616

1717
import numpy as np
@@ -406,7 +406,7 @@ def non_ascii_to_octal(
406406

407407
def build_arg_list( # noqa: PLR0912
408408
kwdict: dict[str, Any],
409-
confdict: dict[str, str] | None = None,
409+
confdict: Mapping[str, Any] | None = None,
410410
infile: str | pathlib.PurePath | Sequence[str | pathlib.PurePath] | None = None,
411411
outfile: str | pathlib.PurePath | None = None,
412412
) -> list[str]:

0 commit comments

Comments
 (0)