Skip to content

Commit 4c7697d

Browse files
jhalerwgk
andauthored
Add const T to docstring generation. (#3020)
* Add const T to docstring generation. * Change order. * See if existing test triggers for a const type. * Add tests. * Fix test. * Remove experiment. * Reformat. * More tests, checks run. * Adding `test_fmt_desc_` prefix to new test functions. * Using pytest.mark.parametrize to 1. condense test; 2. exercise all functions even if one fails; 3. be less platform-specific (e.g. C++ float is not necessarily float32). Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent e25b150 commit 4c7697d

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

include/pybind11/numpy.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,15 +1029,20 @@ struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
10291029

10301030
template <typename T>
10311031
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
1032-
static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
1032+
static constexpr auto name = _<std::is_same<T, float>::value
1033+
|| std::is_same<T, const float>::value
1034+
|| std::is_same<T, double>::value
1035+
|| std::is_same<T, const double>::value>(
10331036
_("numpy.float") + _<sizeof(T)*8>(), _("numpy.longdouble")
10341037
);
10351038
};
10361039

10371040
template <typename T>
10381041
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
10391042
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
1040-
|| std::is_same<typename T::value_type, double>::value>(
1043+
|| std::is_same<typename T::value_type, const float>::value
1044+
|| std::is_same<typename T::value_type, double>::value
1045+
|| std::is_same<typename T::value_type, const double>::value>(
10411046
_("numpy.complex") + _<sizeof(typename T::value_type)*16>(), _("numpy.longcomplex")
10421047
);
10431048
};

tests/test_numpy_array.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,4 +437,10 @@ TEST_SUBMODULE(numpy_array, sm) {
437437
sm.def("accept_double_f_style_forcecast_noconvert",
438438
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
439439
"a"_a.noconvert());
440+
441+
// Check that types returns correct npy format descriptor
442+
sm.def("test_fmt_desc_float", [](py::array_t<float>) {});
443+
sm.def("test_fmt_desc_double", [](py::array_t<double>) {});
444+
sm.def("test_fmt_desc_const_float", [](py::array_t<const float>) {});
445+
sm.def("test_fmt_desc_const_double", [](py::array_t<const double>) {});
440446
}

tests/test_numpy_array.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,19 @@ def test_index_using_ellipsis():
482482
assert a.shape == (6,)
483483

484484

485+
@pytest.mark.parametrize(
486+
"test_func",
487+
[
488+
m.test_fmt_desc_float,
489+
m.test_fmt_desc_double,
490+
m.test_fmt_desc_const_float,
491+
m.test_fmt_desc_const_double,
492+
],
493+
)
494+
def test_format_descriptors_for_floating_point_types(test_func):
495+
assert "numpy.ndarray[numpy.float" in test_func.__doc__
496+
497+
485498
@pytest.mark.parametrize("forcecast", [False, True])
486499
@pytest.mark.parametrize("contiguity", [None, "C", "F"])
487500
@pytest.mark.parametrize("noconvert", [False, True])

0 commit comments

Comments
 (0)