Skip to content

Add type annotations for nodes #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions anytree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,39 @@
__description__ = """Powerful and Lightweight Python Tree Data Structure."""
__url__ = "https://github.com/c0fec0de/anytree"

from . import cachedsearch # noqa
from . import util # noqa
from .iterators import LevelOrderGroupIter # noqa
from .iterators import LevelOrderIter # noqa
from .iterators import PostOrderIter # noqa
from .iterators import PreOrderIter # noqa
from .iterators import ZigZagGroupIter # noqa
from .node import AnyNode # noqa
from .node import LightNodeMixin # noqa
from .node import LoopError # noqa
from .node import Node # noqa
from .node import NodeMixin # noqa
from .node import SymlinkNode # noqa
from .node import SymlinkNodeMixin # noqa
from .node import TreeError # noqa
from .render import AbstractStyle # noqa
from .render import AsciiStyle # noqa
from .render import ContRoundStyle # noqa
from .render import ContStyle # noqa
from .render import DoubleStyle # noqa
from .render import RenderTree # noqa
from .resolver import ChildResolverError # noqa
from .resolver import Resolver # noqa
from .resolver import ResolverError # noqa
from .resolver import RootResolverError # noqa
from .search import CountError # noqa
from .search import find # noqa
from .search import find_by_attr # noqa
from .search import findall # noqa
from .search import findall_by_attr # noqa
from .walker import Walker # noqa
from .walker import WalkError # noqa
# pylint: disable=useless-import-alias
from . import cachedsearch as cachedsearch # noqa
from . import util as util # noqa
from .iterators import LevelOrderGroupIter as LevelOrderGroupIter # noqa
from .iterators import LevelOrderIter as LevelOrderIter # noqa
from .iterators import PostOrderIter as PostOrderIter # noqa
from .iterators import PreOrderIter as PreOrderIter # noqa
from .iterators import ZigZagGroupIter as ZigZagGroupIter # noqa
from .node import AnyNode as AnyNode # noqa
from .node import LightNodeMixin as LightNodeMixin # noqa
from .node import LoopError as LoopError # noqa
from .node import Node as Node # noqa
from .node import NodeMixin as NodeMixin # noqa
from .node import SymlinkNode as SymlinkNode # noqa
from .node import SymlinkNodeMixin as SymlinkNodeMixin # noqa
from .node import TreeError as TreeError # noqa
from .render import AbstractStyle as AbstractStyle # noqa
from .render import AsciiStyle as AsciiStyle # noqa
from .render import ContRoundStyle as ContRoundStyle # noqa
from .render import ContStyle as ContStyle # noqa
from .render import DoubleStyle as DoubleStyle # noqa
from .render import RenderTree as RenderTree # noqa
from .resolver import ChildResolverError as ChildResolverError # noqa
from .resolver import Resolver as Resolver # noqa
from .resolver import ResolverError as ResolverError # noqa
from .resolver import RootResolverError as RootResolverError # noqa
from .search import CountError as CountError # noqa
from .search import find as find # noqa
from .search import find_by_attr as find_by_attr # noqa
from .search import findall as findall # noqa
from .search import findall_by_attr as findall_by_attr # noqa
from .walker import Walker as Walker # noqa
from .walker import WalkError as WalkError # noqa

# legacy
LevelGroupOrderIter = LevelOrderGroupIter
1 change: 0 additions & 1 deletion anytree/exporter/mermaidexporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@


class MermaidExporter:

"""
Mermaid Exporter.

Expand Down
13 changes: 7 additions & 6 deletions anytree/iterators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
* :any:`ZigZagGroupIter`: iterate over tree using level-order strategy returning group for every level
"""

from .abstractiter import AbstractIter # noqa
from .levelordergroupiter import LevelOrderGroupIter # noqa
from .levelorderiter import LevelOrderIter # noqa
from .postorderiter import PostOrderIter # noqa
from .preorderiter import PreOrderIter # noqa
from .zigzaggroupiter import ZigZagGroupIter # noqa
# pylint: disable=useless-import-alias
from .abstractiter import AbstractIter as AbstractIter # noqa
from .levelordergroupiter import LevelOrderGroupIter as LevelOrderGroupIter # noqa
from .levelorderiter import LevelOrderIter as LevelOrderIter # noqa
from .postorderiter import PostOrderIter as PostOrderIter # noqa
from .preorderiter import PreOrderIter as PreOrderIter # noqa
from .zigzaggroupiter import ZigZagGroupIter as ZigZagGroupIter # noqa
48 changes: 37 additions & 11 deletions anytree/iterators/abstractiter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, TypeVar

import six

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from typing_extensions import Self

from ..node.lightnodemixin import LightNodeMixin
from ..node.nodemixin import NodeMixin


NodeT_co = TypeVar("NodeT_co", bound="NodeMixin[Any] | LightNodeMixin[Any]", covariant=True)


class AbstractIter(six.Iterator):
class AbstractIter(Generic[NodeT_co], six.Iterator):
# pylint: disable=R0205
"""
Iterate over tree starting at `node`.
Expand All @@ -14,14 +29,20 @@ class AbstractIter(six.Iterator):
maxlevel (int): maximum descending in the node hierarchy.
"""

def __init__(self, node, filter_=None, stop=None, maxlevel=None):
def __init__(
self,
node: NodeT_co,
filter_: Callable[[NodeT_co], bool] | None = None,
stop: Callable[[NodeT_co], bool] | None = None,
maxlevel: int | None = None,
) -> None:
self.node = node
self.filter_ = filter_
self.stop = stop
self.maxlevel = maxlevel
self.__iter = None
self.__iter: Iterator[NodeT_co] | None = None

def __init(self):
def __init(self) -> Iterator[NodeT_co]:
node = self.node
maxlevel = self.maxlevel
filter_ = self.filter_ or AbstractIter.__default_filter
Expand All @@ -30,31 +51,36 @@ def __init(self):
return self._iter(children, filter_, stop, maxlevel)

@staticmethod
def __default_filter(node):
def __default_filter(node: NodeT_co) -> bool:
# pylint: disable=W0613
return True

@staticmethod
def __default_stop(node):
def __default_stop(node: NodeT_co) -> bool:
# pylint: disable=W0613
return False

def __iter__(self):
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> NodeT_co:
if self.__iter is None:
self.__iter = self.__init()
return next(self.__iter)

@staticmethod
def _iter(children, filter_, stop, maxlevel):
def _iter(
children: Iterable[NodeT_co],
filter_: Callable[[NodeT_co], bool],
stop: Callable[[NodeT_co], bool],
maxlevel: int | None,
) -> Iterator[NodeT_co]:
raise NotImplementedError() # pragma: no cover

@staticmethod
def _abort_at_level(level, maxlevel):
def _abort_at_level(level: int, maxlevel: int | None) -> bool:
return maxlevel is not None and level > maxlevel

@staticmethod
def _get_children(children, stop):
def _get_children(children: Iterable[NodeT_co], stop: Callable[[NodeT_co], bool]) -> list[Any]:
return [child for child in children if not stop(child)]
17 changes: 9 additions & 8 deletions anytree/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
* :any:`LightNodeMixin`: A :any:`NodeMixin` using slots.
"""

from .anynode import AnyNode # noqa
from .exceptions import LoopError # noqa
from .exceptions import TreeError # noqa
from .lightnodemixin import LightNodeMixin # noqa
from .node import Node # noqa
from .nodemixin import NodeMixin # noqa
from .symlinknode import SymlinkNode # noqa
from .symlinknodemixin import SymlinkNodeMixin # noqa
# pylint: disable=useless-import-alias
from .anynode import AnyNode as AnyNode # noqa
from .exceptions import LoopError as LoopError # noqa
from .exceptions import TreeError as TreeError # noqa
from .lightnodemixin import LightNodeMixin as LightNodeMixin # noqa
from .node import Node as Node # noqa
from .nodemixin import NodeMixin as NodeMixin # noqa
from .symlinknode import SymlinkNode as SymlinkNode # noqa
from .symlinknodemixin import SymlinkNodeMixin as SymlinkNodeMixin # noqa
14 changes: 10 additions & 4 deletions anytree/node/anynode.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .nodemixin import NodeMixin
from .util import _repr

if TYPE_CHECKING:
from collections.abc import Iterable

class AnyNode(NodeMixin):

class AnyNode(NodeMixin["AnyNode"]):
"""
A generic tree node with any `kwargs`.

Expand Down Expand Up @@ -92,12 +99,11 @@ class AnyNode(NodeMixin):
... ])
"""

def __init__(self, parent=None, children=None, **kwargs):

def __init__(self, parent: AnyNode | None = None, children: Iterable[AnyNode] | None = None, **kwargs: Any) -> None:
self.__dict__.update(kwargs)
self.parent = parent
if children:
self.children = children

def __repr__(self):
def __repr__(self) -> str:
return _repr(self)
Loading