@@ -26,6 +26,16 @@ limitations under the License.
26
26
#include < pybind11/pybind11.h>
27
27
#include < pybind11/stl.h>
28
28
29
+ #if defined(PYBIND11_HAS_NATIVE_ENUM) || \
30
+ (defined(PYBIND11_INTERNALS_VERSION) && PYBIND11_INTERNALS_VERSION >= 8 )
31
+ #ifndef PYBIND11_HAS_NATIVE_ENUM
32
+ #define PYBIND11_HAS_NATIVE_ENUM
33
+ #endif
34
+ #include < pybind11/native_enum.h>
35
+ #else
36
+ #undef PYBIND11_HAS_NATIVE_ENUM
37
+ #endif
38
+
29
39
namespace optree {
30
40
31
41
py::module_ GetCxxModule (const std::optional<py::module_>& module) {
@@ -205,6 +215,22 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
205
215
#define def_method_pos_only (...) def(__VA_ARGS__)
206
216
#endif
207
217
218
+ #ifdef PYBIND11_HAS_NATIVE_ENUM
219
+ py::native_enum<PyTreeKind>(mod, " PyTreeKind" , " enum.IntEnum" , " The kind of a pytree node." )
220
+ .value (" CUSTOM" , PyTreeKind::Custom, " A custom type." )
221
+ .value (" LEAF" , PyTreeKind::Leaf, " A opaque leaf node." )
222
+ .value (" NONE" , PyTreeKind::None, " None." )
223
+ .value (" TUPLE" , PyTreeKind::Tuple, " A tuple." )
224
+ .value (" LIST" , PyTreeKind::List, " A list." )
225
+ .value (" DICT" , PyTreeKind::Dict, " A dict." )
226
+ .value (" NAMEDTUPLE" , PyTreeKind::NamedTuple, " A collections.namedtuple." )
227
+ .value (" ORDEREDDICT" , PyTreeKind::OrderedDict, " A collections.OrderedDict." )
228
+ .value (" DEFAULTDICT" , PyTreeKind::DefaultDict, " A collections.defaultdict." )
229
+ .value (" DEQUE" , PyTreeKind::Deque, " A collections.deque." )
230
+ .value (" STRUCTSEQUENCE" , PyTreeKind::StructSequence, " A PyStructSequence." )
231
+ .finalize ();
232
+ auto PyTreeKindTypeObject = py::getattr (mod, " PyTreeKind" );
233
+ #else
208
234
auto PyTreeKindTypeObject =
209
235
py::enum_<PyTreeKind>(mod, " PyTreeKind" , " The kind of a pytree node." , py::module_local ())
210
236
.value (" CUSTOM" , PyTreeKind::Custom, " A custom type." )
@@ -218,6 +244,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
218
244
.value (" DEFAULTDICT" , PyTreeKind::DefaultDict, " A collections.defaultdict." )
219
245
.value (" DEQUE" , PyTreeKind::Deque, " A collections.deque." )
220
246
.value (" STRUCTSEQUENCE" , PyTreeKind::StructSequence, " A PyStructSequence." );
247
+ #endif
221
248
auto * const PyTreeKind_Type = reinterpret_cast <PyTypeObject*>(PyTreeKindTypeObject.ptr ());
222
249
PyTreeKind_Type->tp_name = " optree.PyTreeKind" ;
223
250
py::setattr (PyTreeKindTypeObject.ptr (), Py_Get_ID (__module__), Py_Get_ID (optree));
@@ -442,15 +469,23 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
442
469
443
470
#undef def_method_pos_only
444
471
472
+ // Make the types immutable to avoid attribute assignment, modification, and deletion.
445
473
#ifdef Py_TPFLAGS_IMMUTABLETYPE
446
474
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
447
475
PyTreeSpec_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
448
476
PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
477
+
478
+ #ifndef PYBIND11_HAS_NATIVE_ENUM
479
+ // Only run Python C API `PyType_Ready(type)` on C++ types.
480
+ // Re-running `PyType_Ready` for native Python enums can cause unexpected behavior,
481
+ // such as infinite recursion for `repr(e)` and `hash(e)`.
449
482
PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
483
+ #endif
450
484
PyTreeSpec_Type->tp_flags &= ~Py_TPFLAGS_READY;
451
485
PyTreeIter_Type->tp_flags &= ~Py_TPFLAGS_READY;
452
486
#endif
453
487
488
+ // Re-ready types or do consistency checks for the types.
454
489
if (PyType_Ready (PyTreeKind_Type) < 0 ) [[unlikely]] {
455
490
INTERNAL_ERROR (" `PyType_Ready(&PyTreeKind_Type)` failed." );
456
491
}
0 commit comments