Skip to content

Commit 4302474

Browse files
committed
KeysView/ValuesView/ItemsView using Python types
1 parent eeac2f4 commit 4302474

File tree

3 files changed

+294
-36
lines changed

3 files changed

+294
-36
lines changed

include/pybind11/stl_bind.h

Lines changed: 148 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "detail/common.h"
1313
#include "detail/type_caster_base.h"
1414
#include "cast.h"
15+
#include "complex.h"
16+
#include "functional.h"
1517
#include "operators.h"
1618

1719
#include <algorithm>
@@ -483,6 +485,138 @@ void vector_buffer(Class_ &cl) {
483485
cl, detail::any_of<std::is_same<Args, buffer_protocol>...>{});
484486
}
485487

488+
// Issue #3986 and #4529: map C++ types to Python types with typing strings
489+
template <typename T, typename SFINAE = void>
490+
struct type_mapper {
491+
using py_type = T;
492+
static std::string py_name() { return detail::type_info_description(typeid(T)); }
493+
};
494+
495+
template <>
496+
struct type_mapper<std::nullptr_t> {
497+
using py_type = pybind11::none;
498+
static std::string py_name() {
499+
constexpr auto descr = detail::make_caster<std::nullptr_t>::name;
500+
return descr.text;
501+
}
502+
};
503+
504+
template <>
505+
struct type_mapper<bool> {
506+
using py_type = pybind11::bool_;
507+
static std::string py_name() {
508+
constexpr auto descr = detail::make_caster<bool>::name;
509+
return descr.text;
510+
}
511+
};
512+
513+
template <typename T>
514+
struct type_mapper<T, enable_if_t<std::is_arithmetic<T>::value && !is_std_char_type<T>::value>> {
515+
using py_type
516+
= conditional_t<std::is_floating_point<T>::value, pybind11::float_, pybind11::int_>;
517+
static std::string py_name() {
518+
constexpr auto descr = detail::make_caster<T>::name;
519+
return descr.text;
520+
}
521+
};
522+
523+
template <typename T>
524+
struct type_mapper<std::complex<T>> {
525+
using py_type = std::complex<typename type_mapper<T>::py_type>;
526+
static std::string py_name() {
527+
constexpr auto descr = detail::make_caster<std::complex<T>>::name;
528+
return descr.text;
529+
}
530+
};
531+
532+
template <typename T>
533+
struct type_mapper<T, enable_if_t<is_std_char_type<T>::value>> {
534+
using py_type = pybind11::str;
535+
static std::string py_name() {
536+
constexpr auto descr = detail::make_caster<T>::name;
537+
return descr.text;
538+
}
539+
};
540+
541+
template <typename T>
542+
struct type_mapper<T, enable_if_t<is_pyobject<T>::value>> {
543+
using py_type = T;
544+
static std::string py_name() {
545+
constexpr auto descr = detail::make_caster<T>::name;
546+
return descr.text;
547+
}
548+
};
549+
550+
template <typename T>
551+
struct type_mapper<std::shared_ptr<T>> : public type_mapper<T> {};
552+
553+
template <typename T, typename Deleter>
554+
struct type_mapper<std::unique_ptr<T, Deleter>> : public type_mapper<T> {};
555+
556+
template <typename CharT, typename Traits, typename Allocator>
557+
struct type_mapper<std::basic_string<CharT, Traits, Allocator>,
558+
enable_if_t<is_std_char_type<CharT>::value>> {
559+
using py_type = pybind11::str;
560+
static std::string py_name() {
561+
constexpr auto descr
562+
= detail::make_caster<std::basic_string<CharT, Traits, Allocator>>::name;
563+
return descr.text;
564+
}
565+
};
566+
567+
#ifdef PYBIND11_HAS_STRING_VIEW
568+
template <typename CharT, typename Traits>
569+
struct type_mapper<std::basic_string_view<CharT, Traits>,
570+
enable_if_t<is_std_char_type<CharT>::value>> {
571+
using py_type = pybind11::str;
572+
static std::string py_name() {
573+
constexpr auto descr = detail::make_caster<std::basic_string_view<CharT, Traits>>::name;
574+
return descr.text;
575+
}
576+
};
577+
#endif
578+
579+
template <typename T1, typename T2>
580+
struct type_mapper<std::pair<T1, T2>> {
581+
using py_type
582+
= std::tuple<typename type_mapper<T1>::py_type, typename type_mapper<T1>::py_type>;
583+
static std::string py_name() {
584+
return "tuple[" + type_mapper<T1>::py_name() + ", " + type_mapper<T2>::py_name() + "]";
585+
}
586+
};
587+
588+
template <typename... Ts>
589+
struct type_mapper<std::tuple<Ts...>> {
590+
using py_type = std::tuple<typename type_mapper<Ts>::py_type...>;
591+
static std::string py_name() {
592+
std::vector<std::string> names = {type_mapper<Ts>::py_name()...};
593+
std::ostringstream s;
594+
s << "tuple[";
595+
for (size_t i = 0; i < names.size(); ++i) {
596+
s << (i != 0 ? ", " : "") << names[i];
597+
}
598+
s << "]";
599+
return s.str();
600+
}
601+
};
602+
603+
template <typename Return, typename... Args>
604+
struct type_mapper<std::function<Return(Args...)>> {
605+
using retval_type = conditional_t<std::is_same<Return, void>::value, std::nullptr_t, Return>;
606+
using py_type = std::function<typename type_mapper<retval_type>::py_type(
607+
typename type_mapper<Args>::py_type...)>;
608+
static std::string py_name() {
609+
std::vector<std::string> names = {type_mapper<Args>::py_name()...};
610+
std::ostringstream s;
611+
s << "Callable[[";
612+
for (size_t i = 0; i < names.size(); ++i) {
613+
s << (i != 0 ? ", " : "") << names[i];
614+
}
615+
s << "], " << type_mapper<retval_type>::py_name() << "]";
616+
return s.str();
617+
}
618+
};
619+
486620
PYBIND11_NAMESPACE_END(detail)
487621

488622
//
@@ -649,8 +783,7 @@ template <typename KeyType>
649783
struct keys_view {
650784
virtual size_t len() = 0;
651785
virtual iterator iter() = 0;
652-
virtual bool contains(const KeyType &k) = 0;
653-
virtual bool contains(const object &k) = 0;
786+
virtual bool contains(const handle &k) = 0;
654787
virtual ~keys_view() = default;
655788
};
656789

@@ -673,8 +806,10 @@ struct KeysViewImpl : public KeysView {
673806
explicit KeysViewImpl(Map &map) : map(map) {}
674807
size_t len() override { return map.size(); }
675808
iterator iter() override { return make_key_iterator(map.begin(), map.end()); }
676-
bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); }
677-
bool contains(const object &) override { return false; }
809+
bool contains(const handle &k) override {
810+
return detail::make_caster<typename Map::key_type>().load(k, true)
811+
&& map.find(k.template cast<typename Map::key_type>()) != map.end();
812+
}
678813
Map &map;
679814
};
680815

@@ -702,9 +837,11 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
702837
using MappedType = typename Map::mapped_type;
703838
using StrippedKeyType = detail::remove_cvref_t<KeyType>;
704839
using StrippedMappedType = detail::remove_cvref_t<MappedType>;
705-
using KeysView = detail::keys_view<StrippedKeyType>;
706-
using ValuesView = detail::values_view<StrippedMappedType>;
707-
using ItemsView = detail::items_view<StrippedKeyType, StrippedMappedType>;
840+
using PyKeyType = typename detail::type_mapper<StrippedKeyType>::py_type;
841+
using PyMappedType = typename detail::type_mapper<StrippedMappedType>::py_type;
842+
using KeysView = detail::keys_view<PyKeyType>;
843+
using ValuesView = detail::values_view<PyMappedType>;
844+
using ItemsView = detail::items_view<PyKeyType, PyMappedType>;
708845
using Class_ = class_<Map, holder_type>;
709846

710847
// If either type is a non-module-local bound type then make the map binding non-local as well;
@@ -718,20 +855,10 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
718855
}
719856

720857
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
721-
static constexpr auto key_type_descr = detail::make_caster<KeyType>::name;
722-
static constexpr auto mapped_type_descr = detail::make_caster<MappedType>::name;
723-
std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text);
724-
725-
// If key type isn't properly wrapped, fall back to C++ names
726-
if (key_type_name == "%") {
727-
key_type_name = detail::type_info_description(typeid(KeyType));
728-
}
729-
// Similarly for value type:
730-
if (mapped_type_name == "%") {
731-
mapped_type_name = detail::type_info_description(typeid(MappedType));
732-
}
858+
std::string key_type_name = detail::type_mapper<StrippedKeyType>::py_name();
859+
std::string mapped_type_name = detail::type_mapper<StrippedMappedType>::py_name();
733860

734-
// Wrap KeysView[KeyType] if it wasn't already wrapped
861+
// Wrap KeysView[PyKeyType] if it wasn't already wrapped
735862
if (!detail::get_type_info(typeid(KeysView))) {
736863
class_<KeysView> keys_view(
737864
scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local));
@@ -741,10 +868,7 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
741868
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
742869
);
743870
keys_view.def("__contains__",
744-
static_cast<bool (KeysView::*)(const KeyType &)>(&KeysView::contains));
745-
// Fallback for when the object is not of the key type
746-
keys_view.def("__contains__",
747-
static_cast<bool (KeysView::*)(const object &)>(&KeysView::contains));
871+
static_cast<bool (KeysView::*)(const handle &)>(&KeysView::contains));
748872
}
749873
// Similarly for ValuesView:
750874
if (!detail::get_type_info(typeid(ValuesView))) {

tests/test_stl_binders.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,35 @@ TEST_SUBMODULE(stl_binders, m) {
187187
py::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
188188
py::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");
189189

190+
// test_map_view_types
191+
py::bind_map<std::map<std::string, float>>(m, "MapStringFloat");
192+
py::bind_map<std::unordered_map<std::string, float>>(m, "UnorderedMapStringFloat");
193+
py::bind_map<std::map<int16_t, double>>(m, "MapInt16Double");
194+
py::bind_map<std::map<int32_t, double>>(m, "MapInt32Double");
195+
py::bind_map<std::map<int64_t, double>>(m, "MapInt64Double");
196+
py::bind_map<std::map<uint64_t, double>>(m, "MapUInt64Double");
197+
py::bind_map<std::map<std::pair<short, short>, double>>(m, "MapPairShortShortDouble");
198+
py::bind_map<std::map<std::pair<short, long>, std::complex<float>>>(
199+
m, "MapPairShortLongComplexFloat");
200+
py::bind_map<std::map<std::pair<long, short>, std::complex<double>>>(
201+
m, "MapPairLongShortComplexDouble");
202+
py::bind_map<std::map<std::tuple<long, long>, std::complex<double>>>(
203+
m, "MapTupleLongLongComplexDouble");
204+
py::bind_map<std::map<char, std::function<float(int, float)>>>(m,
205+
"MapCharFunctionFloatIntFloat");
206+
py::bind_map<std::map<std::string, std::function<double(long, double)>>>(
207+
m, "MapStringFunctionDoubleLongDouble");
208+
py::bind_map<std::map<std::string, std::function<void(long, double)>>>(
209+
m, "MapStringFunctionVoidLongDouble");
210+
py::bind_map<std::map<std::string, std::nullptr_t>>(m, "MapStringNone");
211+
212+
py::bind_map<std::map<int, std::pair<std::map<int, int>, int>>>(m, "MapIntMapIntIntInt");
213+
py::bind_map<std::map<int, std::pair<std::map<int, int>, long>>>(m, "MapIntMapIntIntLong");
214+
py::bind_map<std::map<int, std::pair<std::map<long, int>, long>>>(m, "MapIntMapLongIntLong");
215+
216+
py::bind_map<std::map<pybind11::int_, int>>(m, "MapPyIntInt");
217+
py::bind_map<std::map<pybind11::int_, pybind11::int_>>(m, "MapPyIntPyInt");
218+
190219
// test_map_string_double_const
191220
py::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst");
192221
py::bind_map<std::unordered_map<std::string, double const>>(m,

0 commit comments

Comments
 (0)