diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 5699d4fb00..b36383f8ac 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1178,11 +1178,13 @@ template enable_if_t::value, T> cast_safe(object &&) { pybind11_fail("Internal error: cast_safe fallback invoked"); } + template -enable_if_t>::value, void> cast_safe(object &&) {} +enable_if_t>::value, void> cast_safe(object &&) {} + template enable_if_t, - std::is_same>>::value, + std::is_same>>::value, T> cast_safe(object &&o) { return pybind11::cast(std::move(o)); diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index 323aa0d22d..462f7adfec 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -230,6 +230,35 @@ inline int test_override_cache(std::shared_ptr const // are rather long). void initialize_inherited_virtuals(py::module_ &m); +namespace issue_4117 { // Broken by PR #3861 + +class Animal { +public: + virtual ~Animal() {} + virtual void *go() = 0; +}; + +class PyAnimal : public Animal { +public: + /* Inherit the constructors */ + using Animal::Animal; + /* Trampoline (need one for each virtual function) */ + void *go() override { + PYBIND11_OVERRIDE_PURE(void *, /* Return type */ + Animal, /* Parent class */ + go, /* Name of function in C++ (must match Python name) */ + ); + } +}; + +void bindings(py::module_ m) { + py::class_(m, "Issue4117Animal") + .def(py::init<>()) + .def("go", &Animal::go); +} + +} // namespace issue_4117 + TEST_SUBMODULE(virtual_functions, m) { // test_override py::class_(m, "ExampleVirt") @@ -410,6 +439,8 @@ TEST_SUBMODULE(virtual_functions, m) { .def("func", &test_override_cache_helper::func); m.def("test_override_cache", test_override_cache); + + issue_4117::bindings(m); } // Inheriting virtual methods. We do two versions here: the repeat-everything version and the diff --git a/tests/test_virtual_functions.py b/tests/test_virtual_functions.py index 4d00d3690d..aa9b8e0230 100644 --- a/tests/test_virtual_functions.py +++ b/tests/test_virtual_functions.py @@ -457,3 +457,12 @@ class Test(m.test_override_cache_helper): for _ in range(1500): assert m.test_override_cache(func()) == 42 assert m.test_override_cache(func2()) == 0 + + +def test_issue_4117(): + class Lynx(m.Issue4117Animal): + def go(self): + return self + + obj = Lynx() + assert obj.go() is obj