Skip to content

Commit d5e39cf

Browse files
committed
Move most of py3 enum impl into the type caster
1 parent c52e140 commit d5e39cf

File tree

2 files changed

+28
-33
lines changed

2 files changed

+28
-33
lines changed

include/pybind11/cast.h

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -890,52 +890,45 @@ template <typename... Tuple> class type_caster<std::tuple<Tuple...>> {
890890
std::tuple<make_caster<Tuple>...> value;
891891
};
892892

893+
template<typename U>
893894
struct py3_enum_info {
894895
handle type = {};
895-
std::unordered_map<long long, handle> values = {};
896+
std::unordered_map<U, handle> values = {};
896897

897898
py3_enum_info() = default;
898899

899900
py3_enum_info(handle type, const dict& values) : type(type) {
900-
for (auto item : values)
901-
this->values[static_cast<long long>(item.second.cast<int>())] = type.attr(item.first);
902-
}
903-
904-
static std::unordered_map<std::type_index, py3_enum_info>& registry() {
905-
static std::unordered_map<std::type_index, py3_enum_info> map = {};
906-
return map;
907-
}
908-
909-
template<typename T>
910-
static void bind(handle type, const dict& values) {
911-
registry()[typeid(T)] = py3_enum_info(type, values);
912-
}
913-
914-
template<typename T>
915-
static const py3_enum_info* get() {
916-
auto it = registry().find(typeid(T));
917-
return it == registry().end() ? nullptr : &it->second;
901+
for (auto item : values) {
902+
this->values[item.second.cast<U>()] = type.attr(item.first);
903+
}
918904
}
919905
};
920906

921907
template<typename T>
922908
struct type_caster<T, enable_if_t<std::is_enum<T>::value>> {
909+
using underlying_type = typename std::underlying_type<T>::type;
910+
923911
private:
924912
using base_caster = type_caster_base<T>;
913+
914+
static std::unique_ptr<py3_enum_info<underlying_type>>& py3_info() {
915+
static std::unique_ptr<py3_enum_info<underlying_type>> info;
916+
return info;
917+
}
918+
925919
base_caster caster;
926-
bool py3 = false;
927920
T value;
928921

929922
public:
930923
template<typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
931924

932-
operator T*() { return py3 ? &value : static_cast<T*>(caster); }
933-
operator T&() { return py3 ? value : static_cast<T&>(caster); }
925+
operator T*() { return py3_info() ? &value : static_cast<T*>(caster); }
926+
operator T&() { return py3_info() ? value : static_cast<T&>(caster); }
934927

935928
static handle cast(const T& src, return_value_policy rvp, handle parent) {
936-
if (auto info = py3_enum_info::get<T>()) {
937-
auto it = info->values.find(static_cast<long long>(src));
938-
if (it == info->values.end())
929+
if (py3_info()) {
930+
auto it = py3_info()->values.find(static_cast<underlying_type>(src));
931+
if (it == py3_info()->values.end())
939932
return {};
940933
return it->second.inc_ref();
941934
}
@@ -945,20 +938,22 @@ struct type_caster<T, enable_if_t<std::is_enum<T>::value>> {
945938
bool load(handle src, bool convert) {
946939
if (!src)
947940
return false;
948-
if (auto info = py3_enum_info::get<T>()) {
949-
py3 = true;
950-
if (!isinstance(src, info->type))
941+
if (py3_info()) {
942+
if (!isinstance(src, py3_info()->type))
951943
return false;
952-
value = static_cast<T>(src.cast<long long>());
944+
value = static_cast<T>(src.cast<underlying_type>());
953945
return true;
954946
}
955-
py3 = false;
956947
return caster.load(src, convert);
957948
}
958949

959950
static PYBIND11_DESCR name() {
960951
return base_caster::name();
961952
}
953+
954+
static void bind(handle type, const dict& values) {
955+
py3_info().reset(new py3_enum_info<underlying_type>(type, values));
956+
}
962957
};
963958

964959
/// Helper class which abstracts away certain actions. Users can provide specializations for

include/pybind11/pybind11.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,8 +1243,8 @@ class py3_enum {
12431243
public:
12441244
using underlying_type = typename std::underlying_type<T>::type;
12451245

1246-
py3_enum(handle scope, const char* name)
1247-
: name(name),
1246+
py3_enum(handle scope, const char* enum_name)
1247+
: name(enum_name),
12481248
parent(scope),
12491249
ctor(module::import("enum").attr("IntEnum")),
12501250
unique(module::import("enum").attr("unique")) {
@@ -1267,7 +1267,7 @@ class py3_enum {
12671267
void update() {
12681268
object type = unique(ctor(name, entries));
12691269
setattr(parent, name, type);
1270-
detail::py3_enum_info::bind<T>(type, entries);
1270+
detail::type_caster<T>::bind(type, entries);
12711271
}
12721272
};
12731273

0 commit comments

Comments
 (0)