Skip to content

Commit 04a0f21

Browse files
committed
feat(PyTreeKind): use pybind11::native_enum to create enum class PyTreeKind
1 parent e5130aa commit 04a0f21

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

include/optree/pymacros.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@ limitations under the License.
2121

2222
#include <pybind11/pybind11.h>
2323

24-
namespace py = pybind11;
25-
26-
#if PY_VERSION_HEX < 0x03090000 // Python 3.9
24+
#if !(defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x03090000) // Python 3.9
2725
#error "Python 3.9 or newer is required."
2826
#endif
2927

28+
#if !(defined(PYBIND11_VERSION_HEX) && PYBIND11_VERSION_HEX >= 0x020C00F0) // pybind11 2.12.0
29+
#error "pybind11 2.12.0 or newer is required."
30+
#endif
31+
32+
namespace py = pybind11;
33+
3034
#ifndef Py_ALWAYS_INLINE
3135
#define Py_ALWAYS_INLINE
3236
#endif

src/optree.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ limitations under the License.
2626
#include <pybind11/pybind11.h>
2727
#include <pybind11/stl.h>
2828

29+
#if defined(PYBIND11_HAS_NATIVE_ENUM) || \
30+
(defined(PYBIND11_INTERNALS_VERSION) && PYBIND11_INTERNALS_VERSION >= 8)
31+
#ifndef PYBIND11_HAS_NATIVE_ENUM
32+
#define PYBIND11_HAS_NATIVE_ENUM
33+
#endif
34+
#include <pybind11/native_enum.h>
35+
#else
36+
#undef PYBIND11_HAS_NATIVE_ENUM
37+
#endif
38+
2939
namespace optree {
3040

3141
py::module_ GetCxxModule(const std::optional<py::module_>& module) {
@@ -205,6 +215,22 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
205215
#define def_method_pos_only(...) def(__VA_ARGS__)
206216
#endif
207217

218+
#ifdef PYBIND11_HAS_NATIVE_ENUM
219+
py::native_enum<PyTreeKind>(mod, "PyTreeKind", "enum.IntEnum", "The kind of a pytree node.")
220+
.value("CUSTOM", PyTreeKind::Custom, "A custom type.")
221+
.value("LEAF", PyTreeKind::Leaf, "A opaque leaf node.")
222+
.value("NONE", PyTreeKind::None, "None.")
223+
.value("TUPLE", PyTreeKind::Tuple, "A tuple.")
224+
.value("LIST", PyTreeKind::List, "A list.")
225+
.value("DICT", PyTreeKind::Dict, "A dict.")
226+
.value("NAMEDTUPLE", PyTreeKind::NamedTuple, "A collections.namedtuple.")
227+
.value("ORDEREDDICT", PyTreeKind::OrderedDict, "A collections.OrderedDict.")
228+
.value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.")
229+
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
230+
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.")
231+
.finalize();
232+
auto PyTreeKindTypeObject = py::getattr(mod, "PyTreeKind");
233+
#else
208234
auto PyTreeKindTypeObject =
209235
py::enum_<PyTreeKind>(mod, "PyTreeKind", "The kind of a pytree node.", py::module_local())
210236
.value("CUSTOM", PyTreeKind::Custom, "A custom type.")
@@ -218,6 +244,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
218244
.value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.")
219245
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
220246
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.");
247+
#endif
221248
auto* const PyTreeKind_Type = reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr());
222249
PyTreeKind_Type->tp_name = "optree.PyTreeKind";
223250
py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));
@@ -442,15 +469,23 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
442469

443470
#undef def_method_pos_only
444471

472+
// Make the types immutable to avoid attribute assignment, modification, and deletion.
445473
#ifdef Py_TPFLAGS_IMMUTABLETYPE
446474
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
447475
PyTreeSpec_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
448476
PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
477+
478+
#ifndef PYBIND11_HAS_NATIVE_ENUM
479+
// Only run Python C API `PyType_Ready(type)` on C++ types.
480+
// Re-running `PyType_Ready` for native Python enums can cause unexpected behavior,
481+
// such as infinite recursion for `repr(e)` and `hash(e)`.
449482
PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
483+
#endif
450484
PyTreeSpec_Type->tp_flags &= ~Py_TPFLAGS_READY;
451485
PyTreeIter_Type->tp_flags &= ~Py_TPFLAGS_READY;
452486
#endif
453487

488+
// Re-ready types or do consistency checks for the types.
454489
if (PyType_Ready(PyTreeKind_Type) < 0) [[unlikely]] {
455490
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
456491
}

0 commit comments

Comments
 (0)