Skip to content

Commit e5ecee9

Browse files
authored
Improve the data type checking for 2-D arrays passed to the GMT C API (#3563)
1 parent 8eb2b4f commit e5ecee9

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

pygmt/clib/session.py

+33-32
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282

8383
REGISTRATIONS = ["GMT_GRID_NODE_REG", "GMT_GRID_PIXEL_REG"]
8484

85-
DTYPES = {
85+
# Dictionary for mapping numpy dtypes to GMT data types.
86+
DTYPES_NUMERIC = {
8687
np.int8: "GMT_CHAR",
8788
np.int16: "GMT_SHORT",
8889
np.int32: "GMT_INT",
@@ -93,10 +94,14 @@
9394
np.uint64: "GMT_ULONG",
9495
np.float32: "GMT_FLOAT",
9596
np.float64: "GMT_DOUBLE",
97+
np.timedelta64: "GMT_LONG",
98+
}
99+
DTYPES_TEXT = {
96100
np.str_: "GMT_TEXT",
97101
np.datetime64: "GMT_DATETIME",
98-
np.timedelta64: "GMT_LONG",
99102
}
103+
DTYPES = DTYPES_NUMERIC | DTYPES_TEXT
104+
100105
# Dictionary for storing the values of GMT constants.
101106
GMT_CONSTANTS = {}
102107

@@ -879,63 +884,59 @@ def _parse_constant(
879884
integer_value = sum(self[part] for part in parts)
880885
return integer_value
881886

882-
def _check_dtype_and_dim(self, array, ndim):
887+
def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int:
883888
"""
884889
Check that a numpy array has the given number of dimensions and is a valid data
885890
type.
886891
887892
Parameters
888893
----------
889-
array : numpy.ndarray
894+
array
890895
The array to be tested.
891-
ndim : int
896+
ndim
892897
The desired number of array dimensions.
893898
894899
Returns
895900
-------
896-
gmt_type : int
901+
gmt_type
897902
The GMT constant value representing this data type.
898903
899904
Raises
900905
------
901906
GMTInvalidInput
902-
If the array has the wrong number of dimensions or
903-
is an unsupported data type.
907+
If the array has the wrong number of dimensions or is an unsupported data
908+
type.
904909
905910
Examples
906911
--------
907-
908912
>>> import numpy as np
909913
>>> data = np.array([1, 2, 3], dtype="float64")
910-
>>> with Session() as ses:
911-
... gmttype = ses._check_dtype_and_dim(data, ndim=1)
912-
... gmttype == ses["GMT_DOUBLE"]
914+
>>> with Session() as lib:
915+
... gmttype = lib._check_dtype_and_dim(data, ndim=1)
916+
... gmttype == lib["GMT_DOUBLE"]
913917
True
914918
>>> data = np.ones((5, 2), dtype="float32")
915-
>>> with Session() as ses:
916-
... gmttype = ses._check_dtype_and_dim(data, ndim=2)
917-
... gmttype == ses["GMT_FLOAT"]
919+
>>> with Session() as lib:
920+
... gmttype = lib._check_dtype_and_dim(data, ndim=2)
921+
... gmttype == lib["GMT_FLOAT"]
918922
True
919923
"""
920-
# Check that the array has the given number of dimensions
924+
# Check that the array has the given number of dimensions.
921925
if array.ndim != ndim:
922-
raise GMTInvalidInput(
923-
f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
924-
)
926+
msg = f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
927+
raise GMTInvalidInput(msg)
925928

926-
# Check that the array has a valid/known data type
927-
if array.dtype.type not in DTYPES:
928-
try:
929-
if array.dtype.type is np.object_:
930-
# Try to convert unknown object type to np.datetime64
931-
array = array_to_datetime(array)
932-
else:
933-
raise ValueError
934-
except ValueError as e:
935-
raise GMTInvalidInput(
936-
f"Unsupported numpy data type '{array.dtype.type}'."
937-
) from e
938-
return self[DTYPES[array.dtype.type]]
929+
# For 1-D arrays, try to convert unknown object type to np.datetime64.
930+
if ndim == 1 and array.dtype.type is np.object_:
931+
with contextlib.suppress(ValueError):
932+
array = array_to_datetime(array)
933+
934+
# 1-D arrays can be numeric or text, 2-D arrays can only be numeric.
935+
valid_dtypes = DTYPES if ndim == 1 else DTYPES_NUMERIC
936+
if (dtype := array.dtype.type) not in valid_dtypes:
937+
msg = f"Unsupported numpy data type '{dtype}'."
938+
raise GMTInvalidInput(msg)
939+
return self[DTYPES[dtype]]
939940

940941
def put_vector(self, dataset: ctp.c_void_p, column: int, vector: np.ndarray):
941942
r"""

0 commit comments

Comments
 (0)