Skip to content

feat(PyTreeKind): use pybind11::native_enum to create enum class PyTreeKind #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,30 @@ insert_final_newline = true

[*.py]
indent_size = 4
src_paths=optree,tests

[*.{yaml,yml}]
[*.{cpp,hpp,cxx,cc,c,h,cu,cuh}]
indent_size = 4

[*.{yaml,yml,json}]
indent_size = 2

[.clang-{format,tidy}]
indent_size = 2

[Makefile]
indent_style = tab

[*.sh]
indent_size = 4

[*.bat]
indent_size = 4
end_of_line = crlf

[*.md]
indent_size = 2
x-soft-wrap-text = true

[*.rst]
indent_size = 4
x-soft-wrap-text = true

[Makefile]
indent_style = tab

[*.{cpp,h,cu,cuh}]
indent_size = 4
15 changes: 11 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ env:
jobs:
build-sdist:
name: Build sdist
if: github.repository == 'metaopt/optree' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
if: |
github.repository_owner == 'metaopt' &&
(github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
Expand Down Expand Up @@ -86,7 +88,9 @@ jobs:

build-wheels:
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.runner }} (${{ matrix.archs }})
if: github.repository == 'metaopt/optree' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
if: |
github.repository_owner == 'metaopt' &&
(github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
runs-on: ${{ matrix.runner }}
strategy:
matrix:
Expand Down Expand Up @@ -241,7 +245,9 @@ jobs:

list-artifacts:
name: List artifacts
if: github.repository == 'metaopt/optree' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
if: |
github.repository_owner == 'metaopt' &&
(github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels]
timeout-minutes: 15
Expand Down Expand Up @@ -275,7 +281,8 @@ jobs:
runs-on: ubuntu-latest
needs: [list-artifacts]
if: |
github.repository == 'metaopt/optree' && github.event_name != 'pull_request' &&
github.repository_owner == 'metaopt' &&
github.event_name != 'pull_request' &&
(github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
(github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 15
Expand Down
12 changes: 12 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ jobs:
run: |
python -m pip install wheel pybind11 -r docs/requirements.txt

- name: Install nightly pybind11
shell: bash
if: |
github.event_name == 'pull_request' &&
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
run: |
python .github/workflows/set_setup_requires.py
echo "::group::pyproject.toml"
cat pyproject.toml
echo "::endgroup::"
python -m pip install --force-reinstall 'pybind11 @ git+https://github.com/pybind/pybind11.git#egg=pybind11'

- name: Install OpTree
run: |
python -m pip install -v --no-build-isolation --editable '.[lint]'
Expand Down
24 changes: 24 additions & 0 deletions .github/workflows/set_setup_requires.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

# pylint: disable=missing-module-docstring

import re
from pathlib import Path


ROOT = Path(__file__).absolute().parents[2]

PYPROJECT_FILE = ROOT / 'pyproject.toml'


if __name__ == '__main__':
PYPROJECT_CONTENT = PYPROJECT_FILE.read_text(encoding='utf-8')

PYPROJECT_FILE.write_text(
data=re.sub(
r'(requires\s*=\s*\[.*"\s*)\bpybind11\b[^"]*(\s*".*\])',
r'\g<1>pybind11 @ git+https://github.com/pybind/pybind11.git#egg=pybind11\g<2>',
string=PYPROJECT_CONTENT,
),
encoding='utf-8',
)
22 changes: 20 additions & 2 deletions .github/workflows/tests-with-pydebug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ env:
PYTHON: "python" # to be updated
PYTHON_TAG: "py3" # to be updated
PYTHON_VERSION: "3" # to be updated
pybind11_VERSION: "stable" # to be updated
PYENV_ROOT: "~/.pyenv" # to be updated
COLUMNS: "128"
COLUMNS: "100"

jobs:
test:
Expand Down Expand Up @@ -284,6 +285,23 @@ jobs:

cdb -version

- name: Use nightly pybind11
shell: bash
if: |
github.event_name == 'pull_request' &&
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
run: |
${{ env.PYTHON }} .github/workflows/set_setup_requires.py
echo "::group::pyproject.toml"
cat pyproject.toml
echo "::endgroup::"
echo "pybind11_VERSION=HEAD" | tee -a "${GITHUB_ENV}"

- name: Test buildable without Python frontend
if: runner.os != 'Windows'
run: |
make cmake-build PYTHON="${{ env.PYTHON }}" && make clean

- name: Install OpTree
run: |
${{ env.PYTHON }} -m pip install -v --editable '.[test]'
Expand Down Expand Up @@ -311,7 +329,7 @@ jobs:
find . -type f -name '*.py[cod]' -delete
find . -depth -type d -name "__pycache__" -exec rm -r "{}" +
if git status --ignored --porcelain | grep -qvE '/$'; then
ls -alh $(git status --ignored --porcelain | grep -vE '/$' | cut -d ' ' -f2)
ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$')
fi

- name: Collect backtraces from coredumps (if any)
Expand Down
32 changes: 28 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ env:
PYTHONUNBUFFERED: "1"
PYTHON: "python" # to be updated
PYTHON_TAG: "py3" # to be updated
COLUMNS: "128"
pybind11_VERSION: "stable" # to be updated
COLUMNS: "100"

jobs:
test:
Expand Down Expand Up @@ -142,6 +143,18 @@ jobs:

cdb -version

- name: Use nightly pybind11
shell: bash
if: |
github.event_name == 'pull_request' &&
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
run: |
${{ env.PYTHON }} .github/workflows/set_setup_requires.py
echo "::group::pyproject.toml"
cat pyproject.toml
echo "::endgroup::"
echo "pybind11_VERSION=HEAD" | tee -a "${GITHUB_ENV}"

- name: Test installable with C++17
shell: bash
if: runner.os != 'Windows'
Expand All @@ -150,14 +163,25 @@ jobs:
set -x
${{ env.PYTHON }} -m venv venv &&
source venv/bin/activate &&
export OPTREE_CXX_WERROR=OFF CMAKE_CXX_STANDARD=17 &&
${{ env.PYTHON }} -m pip install -v . &&
OPTREE_CXX_WERROR=OFF CMAKE_CXX_STANDARD=17 \
${{ env.PYTHON }} -m pip install -v . &&
pushd venv &&
${{ env.PYTHON }} -X dev -Walways -Werror -c 'import optree' &&
popd &&
rm -rf venv
)

- name: Test buildable without Python frontend
if: runner.os != 'Windows'
run: |
(
set -x
${{ env.PYTHON }} -m venv venv &&
source venv/bin/activate &&
make cmake-build && make clean &&
rm -rf venv
)

- name: Install test dependencies
shell: bash
run: |
Expand Down Expand Up @@ -236,7 +260,7 @@ jobs:
find . -type f -name '*.py[cod]' -delete
find . -depth -type d -name "__pycache__" -exec rm -r "{}" +
if git status --ignored --porcelain | grep -qvE '/$'; then
ls -alh $(git status --ignored --porcelain | grep -vE '/$' | cut -d ' ' -f2)
ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$')
fi

- name: Collect backtraces from coredumps (if any)
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
fail_fast: true
- id: debug-statements
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v20.1.3
rev: v20.1.4
hooks:
- id: clang-format
- repo: https://github.com/cpplint/cpplint
Expand Down
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Explicitly set recursion limit for recursion tests by [@XuehaiPan](https://github.com/XuehaiPan) in [#207](https://github.com/metaopt/optree/pull/207).
- Dump build-time meta-information to C extension [@XuehaiPan](https://github.com/XuehaiPan) in [#215](https://github.com/metaopt/optree/pull/215).
- Dump build-time meta-information to C extension by [@XuehaiPan](https://github.com/XuehaiPan) in [#215](https://github.com/metaopt/optree/pull/215).
- Use `pybind11::native_enum` to create enum class `PyTreeKind` if available by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/optree/pull/214).

### Changed

Expand All @@ -23,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Never call `PyType_Ready` twice and use `PyType_Modified` instead by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/optree/pull/214).

### Removed

Expand Down
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ string(STRIP "${CMAKE_CXX_FLAGS_RELEASE}" CMAKE_CXX_FLAGS_RELEASE)
message(STATUS "CXX flags: \"${CMAKE_CXX_FLAGS}\"")
message(STATUS "CXX flags (Debug): \"${CMAKE_CXX_FLAGS_DEBUG}\"")
message(STATUS "CXX flags (Release): \"${CMAKE_CXX_FLAGS_RELEASE}\"")
if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug" AND NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Release")
string(TOUPPER "${CMAKE_BUILD_TYPE}" CMAKE_BUILD_TYPE_UPPER)
string(STRIP "${CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE_UPPER}}" "CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE_UPPER}")
message(STATUS "CXX flags (${CMAKE_BUILD_TYPE}): \"${CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE_UPPER}}\"")
endif()

if(MSVC AND NOT "$ENV{VSCMD_ARG_TGT_ARCH}" STREQUAL "")
message(STATUS "Use VSCMD_ARG_TGT_ARCH: \"$ENV{VSCMD_ARG_TGT_ARCH}\"")
Expand Down
20 changes: 11 additions & 9 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ PATH := $(PATH):$(GOBIN)
PYTHON ?= $(shell command -v python3 || command -v python)
PYTEST ?= $(PYTHON) -X dev -m pytest -Walways
PYTESTOPTS ?=
CMAKE_BUILD_TYPE ?= Debug
CMAKE_BUILD_TYPE_LOWER = $(shell $(PYTHON) -c 'print("$(CMAKE_BUILD_TYPE)".lower())')
CMAKE_CXX_STANDARD ?= 20
OPTREE_CXX_WERROR ?= ON
_GLIBCXX_USE_CXX11_ABI ?= 1
Expand Down Expand Up @@ -196,9 +198,9 @@ xdoctest doctest: xdoctest-install
.PHONY: cmake-configure
cmake-configure: cmake-install
cmake --version
cmake -S . -B cmake-build-debug \
cmake -S . -B cmake-build-$(CMAKE_BUILD_TYPE_LOWER) \
--fresh \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_BUILD_TYPE="$(CMAKE_BUILD_TYPE)" \
-DCMAKE_CXX_STANDARD="$(CMAKE_CXX_STANDARD)" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DPython_EXECUTABLE="$(PYTHON)" \
Expand All @@ -207,7 +209,7 @@ cmake-configure: cmake-install

.PHONY: cmake cmake-build
cmake cmake-build: cmake-configure
cmake --build cmake-build-debug --parallel
cmake --build cmake-build-$(CMAKE_BUILD_TYPE_LOWER) --parallel

.PHONY: clang-format
clang-format: clang-format-install
Expand All @@ -217,7 +219,7 @@ clang-format: clang-format-install
.PHONY: clang-tidy
clang-tidy: clang-tidy-install cmake-configure
clang-tidy --version
clang-tidy --extra-arg="-v" --fix -p=cmake-build-debug $(CXX_FILES)
clang-tidy --extra-arg="-v" --fix -p=cmake-build-$(CMAKE_BUILD_TYPE_LOWER) $(CXX_FILES)

.PHONY: cpplint
cpplint: cpplint-install
Expand All @@ -233,21 +235,21 @@ addlicense: addlicense-install

.PHONY: docstyle
docstyle: docs-install
make -C docs clean || true
$(PYTHON) -m doc8 docs && make -C docs html SPHINXOPTS="-W"
$(MAKE) -C docs clean || true
$(PYTHON) -m doc8 docs && $(MAKE) -C docs html SPHINXOPTS="-W"

.PHONY: docs
docs: docs-install
$(PYTHON) -m sphinx_autobuild --watch $(PROJECT_PATH) --open-browser docs/source docs/build

.PHONY: spelling
spelling: docs-install
make -C docs clean || true
make -C docs spelling SPHINXOPTS="-W"
$(MAKE) -C docs clean || true
$(MAKE) -C docs spelling SPHINXOPTS="-W"

.PHONY: clean-docs
clean-docs:
make -C docs clean || true
$(MAKE) -C docs clean || true

# Utility Functions

Expand Down
10 changes: 7 additions & 3 deletions include/optree/pymacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ limitations under the License.

#include <pybind11/pybind11.h>

namespace py = pybind11;

#if PY_VERSION_HEX < 0x03090000 // Python 3.9
#if !(defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x03090000) // Python 3.9
#error "Python 3.9 or newer is required."
#endif

#if !(defined(PYBIND11_VERSION_HEX) && PYBIND11_VERSION_HEX >= 0x020C00F0) // pybind11 2.12.0
#error "pybind11 2.12.0 or newer is required."
#endif

namespace py = pybind11;

#ifndef Py_ALWAYS_INLINE
#define Py_ALWAYS_INLINE
#endif
Expand Down
2 changes: 2 additions & 0 deletions include/optree/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ enum class PyTreeKind : std::uint8_t {
DefaultDict, // A collections.defaultdict
Deque, // A collections.deque
StructSequence, // A PyStructSequence
NumKinds, // Number of kinds (placed at the end)
};

constexpr PyTreeKind kCustom = PyTreeKind::Custom;
Expand All @@ -63,6 +64,7 @@ constexpr PyTreeKind kOrderedDict = PyTreeKind::OrderedDict;
constexpr PyTreeKind kDefaultDict = PyTreeKind::DefaultDict;
constexpr PyTreeKind kDeque = PyTreeKind::Deque;
constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence;
constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds;

// Registry of custom node types.
class PyTreeTypeRegistry {
Expand Down
Loading
Loading