Skip to content

Allow to use prefix trees as nodes to parse function returns. #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
`@pytask.mark.produces`.
- {pull}`402` replaces ABCs with protocols allowing for more flexibility for users
implementing their own nodes.
- {pull}`404` allows using function returns to define task products.
- {pull}`406` allows matching function returns to node annotations with prefix trees.

## 0.3.2 - 2023-06-07

Expand Down
19 changes: 14 additions & 5 deletions src/_pytask/collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def _find_common_ancestor_of_all_nodes(
all_paths.extend(
x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode)
)
all_paths.extend(x.path for x in tree_leaves(task.produces))
all_paths.extend(
x.path for x in tree_leaves(task.produces) if isinstance(x, PPathNode)
)

common_ancestor = find_common_ancestor(*all_paths, *paths)

Expand Down Expand Up @@ -219,10 +221,17 @@ def _print_collected_tasks(

task_branch.add(Text.assemble(FILE_ICON, "<Dependency ", text, ">"))

for node in sorted(tree_leaves(task.produces), key=lambda x: x.path):
reduced_node_name = str(relative_to(node.path, common_ancestor))
url_style = create_url_style_for_path(node.path, editor_url_scheme)
text = Text(reduced_node_name, style=url_style)
for node in sorted(
tree_leaves(task.produces), key=lambda x: getattr(x, "path", x.name)
):
if isinstance(node, PPathNode):
reduced_node_name = str(relative_to(node.path, common_ancestor))
url_style = create_url_style_for_path(
node.path, editor_url_scheme
)
text = Text(reduced_node_name, style=url_style)
else:
text = Text(node.name)
task_branch.add(Text.assemble(FILE_ICON, "<Product ", text, ">"))

console.print(tree)
11 changes: 2 additions & 9 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from _pytask.mark_utils import remove_marks
from _pytask.models import NodeInfo
from _pytask.node_protocols import Node
from _pytask.node_protocols import PPathNode
from _pytask.nodes import ProductType
from _pytask.nodes import PythonNode
from _pytask.shared import find_duplicates
Expand Down Expand Up @@ -580,15 +579,9 @@ def _collect_product(
"type 'str' and 'pathlib.Path' or the same values nested in "
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
)
# The parameter defaults only support Path objects.
if not isinstance(node, (Path, PPathNode)) and not is_string_allowed:
raise ValueError(
"If you declare products with 'Annotated[..., Product]', only values of "
"type 'pathlib.Path' optionally nested in tuples, lists, and "
f"dictionaries are allowed. Here, {node!r} has type {type(node)}."
)

if isinstance(node, str):
# If we encounter a string and it is allowed, convert it to a path.
if isinstance(node, str) and is_string_allowed:
node = Path(node)
node_info = node_info._replace(value=node)

Expand Down
22 changes: 11 additions & 11 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path
from typing import Any
from typing import Callable
from typing import NoReturn
from typing import TYPE_CHECKING

from _pytask.node_protocols import MetaNode
Expand Down Expand Up @@ -82,17 +81,18 @@ def execute(self, **kwargs: Any) -> None:
if "return" in self.produces:
structure_out = tree_structure(out)
structure_return = tree_structure(self.produces["return"])
if not structure_out == structure_return:
# strict must be false when none is leaf.
if not structure_return.is_prefix(structure_out, strict=False):
raise ValueError(
"The structure of the function return does not match the structure "
f"of the return annotation.\n\nFunction return: {structure_out}\n\n"
f"Return annotation: {structure_return}"
"The structure of the return annotation is not a subtree of the "
"structure of the function return.\n\nFunction return: "
f"{structure_out}\n\nReturn annotation: {structure_return}"
)

for out_, return_ in zip(
tree_leaves(out), tree_leaves(self.produces["return"])
):
return_.save(out_)
nodes = tree_leaves(self.produces["return"])
values = structure_return.flatten_up_to(out)
for node, value in zip(nodes, values):
node.save(value)

def add_report_section(self, when: str, key: str, content: str) -> None:
"""Add sections which will be displayed in report like stdout or stderr."""
Expand Down Expand Up @@ -174,9 +174,9 @@ def load(self) -> Any:
"""Load the value."""
return self.value

def save(self, value: Any) -> NoReturn:
def save(self, value: Any) -> None:
"""Save the value."""
raise NotImplementedError
self.value = value

def from_annot(self, value: Any) -> None:
"""Set the value from a function annotation."""
Expand Down
13 changes: 13 additions & 0 deletions src/_pytask/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@

import functools
from pathlib import Path
from typing import Any

import optree
from optree import PyTree
from optree import tree_flatten_with_path as _optree_tree_flatten_with_path
from optree import tree_leaves as _optree_tree_leaves
from optree import tree_map as _optree_tree_map
from optree import tree_map_with_path as _optree_tree_map_with_path
from optree import tree_structure as _optree_tree_structure


__all__ = [
"tree_flatten_with_path",
"tree_leaves",
"tree_map",
"tree_map_with_path",
Expand All @@ -34,3 +37,13 @@
tree_structure = functools.partial(
_optree_tree_structure, none_is_leaf=True, namespace="pytask"
)
tree_flatten_with_path = functools.partial(
_optree_tree_flatten_with_path, none_is_leaf=True, namespace="pytask"
)


def tree_index(path: tuple[Any, ...], tree: PyTree) -> Any:
    """Index a tree with a path.

    Parameters
    ----------
    path
        Sequence of keys or indices, applied one after another from the outside in.
        An empty (falsy) path selects the tree itself.
    tree
        The nested container to index. Every intermediate level must support
        ``__getitem__`` for the corresponding entry of ``path``.

    Returns
    -------
    Any
        The subtree or leaf found after applying every entry of ``path``.

    Raises
    ------
    KeyError, IndexError, TypeError
        Propagated unchanged from the underlying ``__getitem__`` call when an entry
        of ``path`` does not exist in the tree.

    """
    if not path:
        return tree

    # Walk down iteratively instead of recursing on ``path[1:]``: this avoids
    # copying the remaining path at every level and does not depend on the
    # interpreter's recursion limit for deep trees.
    node = tree
    for key in path:
        node = node[key]
    return node
4 changes: 2 additions & 2 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def task_write_text(produces="out.txt"):

@pytest.mark.end_to_end()
def test_collect_string_product_raises_error_with_annotation(runner, tmp_path):
"""The string is not converted to a path."""
source = """
from pytask import Product
from typing_extensions import Annotated
Expand All @@ -349,8 +350,7 @@ def task_write_text(out: Annotated[str, Product] = "out.txt") -> None:
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "If you declare products with 'Annotated[..., Product]'" in result.output
assert result.exit_code == ExitCode.FAILED


@pytest.mark.end_to_end()
Expand Down
27 changes: 27 additions & 0 deletions tests/test_collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,30 @@ def task_example(
result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])
assert result.exit_code == ExitCode.OK
assert "node-name" in result.output


# End-to-end check of ``collect --nodes`` for a task whose return annotation is a
# nested list/tuple of ``PythonNode`` objects. The task returns a list whose first
# entry is a whole dict, sitting at the position of the single node named "dict".
# The command must exit with ``ExitCode.OK`` and every node name must appear in the
# printed output.
@pytest.mark.end_to_end()
def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path):
source = """
from pathlib import Path
from typing import Any
from typing_extensions import Annotated
from pytask import PythonNode
from typing import Dict

nodes = [
PythonNode(name="dict"),
(PythonNode(name="tuple1"), PythonNode(name="tuple2")),
PythonNode(name="int")
]

def task_example() -> Annotated[Dict[str, str], nodes]:
return [{"first": "a", "second": "b"}, (1, 2), 1]
"""
# Write the task module into the temporary directory and run the collect command.
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])
assert result.exit_code == ExitCode.OK
# Each node declared in the return annotation is expected in the output by name.
assert "dict" in result.output
assert "tuple1" in result.output
assert "tuple2" in result.output
assert "int" in result.output
40 changes: 40 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,43 @@ def task_example() -> Annotated[str, (node1, node2)]:
assert result.exit_code == ExitCode.FAILED
assert "Function return: PyTreeSpec(*, NoneIsLeaf)" in result.output
assert "Return annotation: PyTreeSpec((*, *), NoneIsLeaf)" in result.output


# End-to-end check that a task returning a dict can be annotated with a single
# ``PythonNode``: running the CLI on the task directory must exit with
# ``ExitCode.OK``.
@pytest.mark.end_to_end()
def test_pytree_and_python_node_as_return(runner, tmp_path):
source = """
from pathlib import Path
from typing import Any
from typing_extensions import Annotated
from pytask import PythonNode
from typing import Dict

def task_example() -> Annotated[Dict[str, str], PythonNode(name="result")]:
return {"first": "a", "second": "b"}
"""
# Write the task module into the temporary directory and run the CLI on it.
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.OK

# End-to-end check that a task can be run when its return annotation is a nested
# list/tuple of ``PythonNode`` objects and the function returns a list whose first
# entry is a whole dict (at the position of the node named "dict"). Running the CLI
# on the task directory must exit with ``ExitCode.OK``.
@pytest.mark.end_to_end()
def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path):
source = """
from pathlib import Path
from typing import Any
from typing_extensions import Annotated
from pytask import PythonNode
from typing import Dict

nodes = [
PythonNode(name="dict"),
(PythonNode(name="tuple1"), PythonNode(name="tuple2")),
PythonNode(name="int")
]

def task_example() -> Annotated[Dict[str, str], nodes]:
return [{"first": "a", "second": "b"}, (1, 2), 1]
"""
# Write the task module into the temporary directory and run the CLI on it.
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.OK
18 changes: 18 additions & 0 deletions tests/test_tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from _pytask.outcomes import ExitCode
from _pytask.tree_util import tree_map
from _pytask.tree_util import tree_structure
from pytask import cli
from pytask import main

Expand Down Expand Up @@ -80,3 +81,20 @@ def task_example(produces):
assert "0." in result.output
assert "Size of Products" in result.output
assert "86 bytes" in result.output


@pytest.mark.unit()
@pytest.mark.parametrize(
    ("prefix_tree", "full_tree", "strict", "expected"),
    [
        # With strict=True a single leaf is not accepted as a prefix of an identical
        # single leaf, which is why function returns are parsed with strict=False.
        (1, 1, True, False),
        (1, 1, False, True),
        # A flat dict is a prefix of a dict whose "b" entry is nested one level
        # deeper, regardless of the strict flag.
        ({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, False, True),
        ({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, True, True),
    ],
)
def test_is_prefix(prefix_tree, full_tree, strict, expected):
    """Check ``is_prefix`` on tree structures for strict and non-strict matching."""
    spec_of_prefix, spec_of_full = (
        tree_structure(tree) for tree in (prefix_tree, full_tree)
    )
    outcome = spec_of_prefix.is_prefix(spec_of_full, strict=strict)
    assert outcome is expected