Skip to content

Commit 31b7e14

Browse files
authored
bugfix: removing typing and duplicate class_ for KeysView/ValuesView/ItemsView. Fix #4529 (#4985)
* remove typing for KeysView/ValuesView/ItemsView * add tests for map view types
1 parent aec6cc5 commit 31b7e14

File tree

3 files changed

+62
-57
lines changed

3 files changed

+62
-57
lines changed

include/pybind11/stl_bind.h

+25-54
Original file line numberDiff line numberDiff line change
@@ -645,49 +645,50 @@ auto map_if_insertion_operator(Class_ &cl, std::string const &name)
645645
"Return the canonical string representation of this map.");
646646
}
647647

648-
template <typename KeyType>
649648
struct keys_view {
650649
virtual size_t len() = 0;
651650
virtual iterator iter() = 0;
652-
virtual bool contains(const KeyType &k) = 0;
653-
virtual bool contains(const object &k) = 0;
651+
virtual bool contains(const handle &k) = 0;
654652
virtual ~keys_view() = default;
655653
};
656654

657-
template <typename MappedType>
658655
struct values_view {
659656
virtual size_t len() = 0;
660657
virtual iterator iter() = 0;
661658
virtual ~values_view() = default;
662659
};
663660

664-
template <typename KeyType, typename MappedType>
665661
struct items_view {
666662
virtual size_t len() = 0;
667663
virtual iterator iter() = 0;
668664
virtual ~items_view() = default;
669665
};
670666

671-
template <typename Map, typename KeysView>
672-
struct KeysViewImpl : public KeysView {
667+
template <typename Map>
668+
struct KeysViewImpl : public detail::keys_view {
673669
explicit KeysViewImpl(Map &map) : map(map) {}
674670
size_t len() override { return map.size(); }
675671
iterator iter() override { return make_key_iterator(map.begin(), map.end()); }
676-
bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); }
677-
bool contains(const object &) override { return false; }
672+
bool contains(const handle &k) override {
673+
try {
674+
return map.find(k.template cast<typename Map::key_type>()) != map.end();
675+
} catch (const cast_error &) {
676+
return false;
677+
}
678+
}
678679
Map &map;
679680
};
680681

681-
template <typename Map, typename ValuesView>
682-
struct ValuesViewImpl : public ValuesView {
682+
template <typename Map>
683+
struct ValuesViewImpl : public detail::values_view {
683684
explicit ValuesViewImpl(Map &map) : map(map) {}
684685
size_t len() override { return map.size(); }
685686
iterator iter() override { return make_value_iterator(map.begin(), map.end()); }
686687
Map &map;
687688
};
688689

689-
template <typename Map, typename ItemsView>
690-
struct ItemsViewImpl : public ItemsView {
690+
template <typename Map>
691+
struct ItemsViewImpl : public detail::items_view {
691692
explicit ItemsViewImpl(Map &map) : map(map) {}
692693
size_t len() override { return map.size(); }
693694
iterator iter() override { return make_iterator(map.begin(), map.end()); }
@@ -700,11 +701,9 @@ template <typename Map, typename holder_type = std::unique_ptr<Map>, typename...
700701
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&...args) {
701702
using KeyType = typename Map::key_type;
702703
using MappedType = typename Map::mapped_type;
703-
using StrippedKeyType = detail::remove_cvref_t<KeyType>;
704-
using StrippedMappedType = detail::remove_cvref_t<MappedType>;
705-
using KeysView = detail::keys_view<StrippedKeyType>;
706-
using ValuesView = detail::values_view<StrippedMappedType>;
707-
using ItemsView = detail::items_view<StrippedKeyType, StrippedMappedType>;
704+
using KeysView = detail::keys_view;
705+
using ValuesView = detail::values_view;
706+
using ItemsView = detail::items_view;
708707
using Class_ = class_<Map, holder_type>;
709708

710709
// If either type is a non-module-local bound type then make the map binding non-local as well;
@@ -718,39 +717,20 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
718717
}
719718

720719
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
721-
static constexpr auto key_type_descr = detail::make_caster<KeyType>::name;
722-
static constexpr auto mapped_type_descr = detail::make_caster<MappedType>::name;
723-
std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text);
724720

725-
// If key type isn't properly wrapped, fall back to C++ names
726-
if (key_type_name == "%") {
727-
key_type_name = detail::type_info_description(typeid(KeyType));
728-
}
729-
// Similarly for value type:
730-
if (mapped_type_name == "%") {
731-
mapped_type_name = detail::type_info_description(typeid(MappedType));
732-
}
733-
734-
// Wrap KeysView[KeyType] if it wasn't already wrapped
721+
// Wrap KeysView if it wasn't already wrapped
735722
if (!detail::get_type_info(typeid(KeysView))) {
736-
class_<KeysView> keys_view(
737-
scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local));
723+
class_<KeysView> keys_view(scope, "KeysView", pybind11::module_local(local));
738724
keys_view.def("__len__", &KeysView::len);
739725
keys_view.def("__iter__",
740726
&KeysView::iter,
741727
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
742728
);
743-
keys_view.def("__contains__",
744-
static_cast<bool (KeysView::*)(const KeyType &)>(&KeysView::contains));
745-
// Fallback for when the object is not of the key type
746-
keys_view.def("__contains__",
747-
static_cast<bool (KeysView::*)(const object &)>(&KeysView::contains));
729+
keys_view.def("__contains__", &KeysView::contains);
748730
}
749731
// Similarly for ValuesView:
750732
if (!detail::get_type_info(typeid(ValuesView))) {
751-
class_<ValuesView> values_view(scope,
752-
("ValuesView[" + mapped_type_name + "]").c_str(),
753-
pybind11::module_local(local));
733+
class_<ValuesView> values_view(scope, "ValuesView", pybind11::module_local(local));
754734
values_view.def("__len__", &ValuesView::len);
755735
values_view.def("__iter__",
756736
&ValuesView::iter,
@@ -759,10 +739,7 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
759739
}
760740
// Similarly for ItemsView:
761741
if (!detail::get_type_info(typeid(ItemsView))) {
762-
class_<ItemsView> items_view(
763-
scope,
764-
("ItemsView[" + key_type_name + ", ").append(mapped_type_name + "]").c_str(),
765-
pybind11::module_local(local));
742+
class_<ItemsView> items_view(scope, "ItemsView", pybind11::module_local(local));
766743
items_view.def("__len__", &ItemsView::len);
767744
items_view.def("__iter__",
768745
&ItemsView::iter,
@@ -788,25 +765,19 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
788765

789766
cl.def(
790767
"keys",
791-
[](Map &m) {
792-
return std::unique_ptr<KeysView>(new detail::KeysViewImpl<Map, KeysView>(m));
793-
},
768+
[](Map &m) { return std::unique_ptr<KeysView>(new detail::KeysViewImpl<Map>(m)); },
794769
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
795770
);
796771

797772
cl.def(
798773
"values",
799-
[](Map &m) {
800-
return std::unique_ptr<ValuesView>(new detail::ValuesViewImpl<Map, ValuesView>(m));
801-
},
774+
[](Map &m) { return std::unique_ptr<ValuesView>(new detail::ValuesViewImpl<Map>(m)); },
802775
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
803776
);
804777

805778
cl.def(
806779
"items",
807-
[](Map &m) {
808-
return std::unique_ptr<ItemsView>(new detail::ItemsViewImpl<Map, ItemsView>(m));
809-
},
780+
[](Map &m) { return std::unique_ptr<ItemsView>(new detail::ItemsViewImpl<Map>(m)); },
810781
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
811782
);
812783

tests/test_stl_binders.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ TEST_SUBMODULE(stl_binders, m) {
192192
py::bind_map<std::unordered_map<std::string, double const>>(m,
193193
"UnorderedMapStringDoubleConst");
194194

195+
// test_map_view_types
196+
py::bind_map<std::map<std::string, float>>(m, "MapStringFloat");
197+
py::bind_map<std::unordered_map<std::string, float>>(m, "UnorderedMapStringFloat");
198+
199+
py::bind_map<std::map<std::pair<double, int>, int32_t>>(m, "MapPairDoubleIntInt32");
200+
py::bind_map<std::map<std::pair<double, int>, int64_t>>(m, "MapPairDoubleIntInt64");
201+
202+
py::bind_map<std::map<int, py::object>>(m, "MapIntObject");
203+
py::bind_map<std::map<std::string, py::object>>(m, "MapStringObject");
204+
195205
py::class_<E_nc>(m, "ENC").def(py::init<int>()).def_readwrite("value", &E_nc::value);
196206

197207
// test_noncopyable_containers

tests/test_stl_binders.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,9 @@ def test_map_view_types():
317317
map_string_double_const = m.MapStringDoubleConst()
318318
unordered_map_string_double_const = m.UnorderedMapStringDoubleConst()
319319

320-
assert map_string_double.keys().__class__.__name__ == "KeysView[str]"
321-
assert map_string_double.values().__class__.__name__ == "ValuesView[float]"
322-
assert map_string_double.items().__class__.__name__ == "ItemsView[str, float]"
320+
assert map_string_double.keys().__class__.__name__ == "KeysView"
321+
assert map_string_double.values().__class__.__name__ == "ValuesView"
322+
assert map_string_double.items().__class__.__name__ == "ItemsView"
323323

324324
keys_type = type(map_string_double.keys())
325325
assert type(unordered_map_string_double.keys()) is keys_type
@@ -336,6 +336,30 @@ def test_map_view_types():
336336
assert type(map_string_double_const.items()) is items_type
337337
assert type(unordered_map_string_double_const.items()) is items_type
338338

339+
map_string_float = m.MapStringFloat()
340+
unordered_map_string_float = m.UnorderedMapStringFloat()
341+
342+
assert type(map_string_float.keys()) is keys_type
343+
assert type(unordered_map_string_float.keys()) is keys_type
344+
assert type(map_string_float.values()) is values_type
345+
assert type(unordered_map_string_float.values()) is values_type
346+
assert type(map_string_float.items()) is items_type
347+
assert type(unordered_map_string_float.items()) is items_type
348+
349+
map_pair_double_int_int32 = m.MapPairDoubleIntInt32()
350+
map_pair_double_int_int64 = m.MapPairDoubleIntInt64()
351+
352+
assert type(map_pair_double_int_int32.values()) is values_type
353+
assert type(map_pair_double_int_int64.values()) is values_type
354+
355+
map_int_object = m.MapIntObject()
356+
map_string_object = m.MapStringObject()
357+
358+
assert type(map_int_object.keys()) is keys_type
359+
assert type(map_string_object.keys()) is keys_type
360+
assert type(map_int_object.items()) is items_type
361+
assert type(map_string_object.items()) is items_type
362+
339363

340364
def test_recursive_vector():
341365
recursive_vector = m.RecursiveVector()

0 commit comments

Comments
 (0)