From 255681ad12bc36140e794e30225947487bcb6362 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Mon, 28 Aug 2023 08:40:14 +0200 Subject: [PATCH 1/5] Allow to match returns with prefix trees. --- src/_pytask/collect_utils.py | 11 ++--------- src/_pytask/nodes.py | 21 ++++++++++----------- src/_pytask/tree_util.py | 13 +++++++++++++ tests/test_collect.py | 4 ++-- tests/test_execute.py | 16 ++++++++++++++++ 5 files changed, 43 insertions(+), 22 deletions(-) diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index f2cdd300..7d9e7e20 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -18,7 +18,6 @@ from _pytask.mark_utils import remove_marks from _pytask.models import NodeInfo from _pytask.node_protocols import Node -from _pytask.node_protocols import PPathNode from _pytask.nodes import ProductType from _pytask.nodes import PythonNode from _pytask.shared import find_duplicates @@ -580,15 +579,9 @@ def _collect_product( "type 'str' and 'pathlib.Path' or the same values nested in " f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}." ) - # The parameter defaults only support Path objects. - if not isinstance(node, (Path, PPathNode)) and not is_string_allowed: - raise ValueError( - "If you declare products with 'Annotated[..., Product]', only values of " - "type 'pathlib.Path' optionally nested in tuples, lists, and " - f"dictionaries are allowed. Here, {node!r} has type {type(node)}." - ) - if isinstance(node, str): + # If we encounter a string and it is allowed, convert it to a path. + if isinstance(node, str) and is_string_allowed: node = Path(node) node_info = node_info._replace(value=node) diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 2d8088fd..cb5e18e6 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any from typing import Callable -from typing import NoReturn from typing import TYPE_CHECKING from _pytask.node_protocols import MetaNode @@ -82,17 +81,17 @@ def execute(self, **kwargs: Any) -> None: if "return" in self.produces: structure_out = tree_structure(out) structure_return = tree_structure(self.produces["return"]) - if not structure_out == structure_return: + if not structure_return.is_prefix(structure_out, strict=False): raise ValueError( - "The structure of the function return does not match the structure " - f"of the return annotation.\n\nFunction return: {structure_out}\n\n" - f"Return annotation: {structure_return}" + "The structure of the return annotation is not a subtree of the " + "structure of the function return.\n\nFunction return: " + f"{structure_out}\n\nReturn annotation: {structure_return}" ) - for out_, return_ in zip( - tree_leaves(out), tree_leaves(self.produces["return"]) - ): - return_.save(out_) + nodes = tree_leaves(self.produces["return"]) + values = structure_return.flatten_up_to(out) + for node, value in zip(nodes, values): + node.save(value) def add_report_section(self, when: str, key: str, content: str) -> None: """Add sections which will be displayed in report like stdout or stderr.""" @@ -174,9 +173,9 @@ def load(self) -> Any: """Load the value.""" return self.value - def save(self, value: Any) -> NoReturn: + def save(self, value: Any) -> None: """Save the value.""" - raise NotImplementedError + self.value = value def from_annot(self, value: Any) -> None: """Set the value from a function annotation.""" diff --git a/src/_pytask/tree_util.py b/src/_pytask/tree_util.py index 80edcc71..c0fa1b48 100644 --- a/src/_pytask/tree_util.py +++ b/src/_pytask/tree_util.py @@ -3,9 +3,11 @@ import functools from pathlib import Path +from typing import Any import optree from optree import PyTree +from optree import tree_flatten_with_path as _optree_tree_flatten_with_path from optree import tree_leaves as _optree_tree_leaves from optree import tree_map as _optree_tree_map from optree import tree_map_with_path as _optree_tree_map_with_path @@ -13,6 +15,7 @@ __all__ = [ + "tree_flatten_with_path", "tree_leaves", "tree_map", "tree_map_with_path", @@ -34,3 +37,13 @@ tree_structure = functools.partial( _optree_tree_structure, none_is_leaf=True, namespace="pytask" ) +tree_flatten_with_path = functools.partial( + _optree_tree_flatten_with_path, none_is_leaf=True, namespace="pytask" +) + + +def tree_index(path: tuple[Any, ...], tree: PyTree) -> Any: + """Index a tree with a path.""" + if not path: + return tree + return tree_index(path[1:], tree[path[0]]) diff --git a/tests/test_collect.py b/tests/test_collect.py index 19832858..0c56db3c 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -340,6 +340,7 @@ def task_write_text(produces="out.txt"): @pytest.mark.end_to_end() def test_collect_string_product_raises_error_with_annotation(runner, tmp_path): + """The string is not converted to a path.""" source = """ from pytask import Product from typing_extensions import Annotated @@ -349,8 +350,7 @@ def task_write_text(out: Annotated[str, Product] = "out.txt") -> None: """ tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) result = runner.invoke(cli, [tmp_path.as_posix()]) - assert result.exit_code == ExitCode.COLLECTION_FAILED - assert "If you declare products with 'Annotated[..., Product]'" in result.output + assert result.exit_code == ExitCode.FAILED @pytest.mark.end_to_end() diff --git a/tests/test_execute.py b/tests/test_execute.py index c9a0f942..83117ecb 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -666,3 +666,19 @@ def task_example() -> Annotated[str, (node1, node2)]: assert result.exit_code == ExitCode.FAILED assert "Function return: PyTreeSpec(*, NoneIsLeaf)" in result.output assert "Return annotation: PyTreeSpec((*, *), NoneIsLeaf)" in result.output + + +@pytest.mark.end_to_end() +def test_pytree_and_python_node(runner, tmp_path): + source = """ + from pathlib import Path + from typing import Any + from typing_extensions import Annotated + from pytask import PythonNode + + def task_example() -> Annotated[dict[str, str], PythonNode(name="result")]: + return {"first": "a", "second": "b"} + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK From d6d40d5fa5d14311253cc34b0afbd62b474d388e Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Mon, 28 Aug 2023 09:13:39 +0200 Subject: [PATCH 2/5] Add tests. --- docs/source/changes.md | 2 ++ src/_pytask/collect_command.py | 19 ++++++++++++++----- tests/test_collect_command.py | 26 ++++++++++++++++++++++++++ tests/test_execute.py | 24 +++++++++++++++++++++++- 4 files changed, 65 insertions(+), 6 deletions(-) diff --git a/docs/source/changes.md b/docs/source/changes.md index 7a535918..7320f560 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -23,6 +23,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and `@pytask.mark.produces`. - {pull}`402` replaces ABCs with protocols allowing for more flexibility for users implementing their own nodes. +- {pull}`404` allows to use function returns to define task products. +- {pull}`405` allows to match function returns to node annotations with prefix trees. ## 0.3.2 - 2023-06-07 diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index a8b3fcf7..39165bee 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -127,7 +127,9 @@ def _find_common_ancestor_of_all_nodes( all_paths.extend( x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode) ) - all_paths.extend(x.path for x in tree_leaves(task.produces)) + all_paths.extend( + x.path for x in tree_leaves(task.produces) if isinstance(x, PPathNode) + ) common_ancestor = find_common_ancestor(*all_paths, *paths) @@ -219,10 +221,17 @@ def _print_collected_tasks( task_branch.add(Text.assemble(FILE_ICON, "")) - for node in sorted(tree_leaves(task.produces), key=lambda x: x.path): - reduced_node_name = str(relative_to(node.path, common_ancestor)) - url_style = create_url_style_for_path(node.path, editor_url_scheme) - text = Text(reduced_node_name, style=url_style) + for node in sorted( + tree_leaves(task.produces), key=lambda x: getattr(x, "path", x.name) + ): + if isinstance(node, PPathNode): + reduced_node_name = str(relative_to(node.path, common_ancestor)) + url_style = create_url_style_for_path( + node.path, editor_url_scheme + ) + text = Text(reduced_node_name, style=url_style) + else: + text = Text(node.name) task_branch.add(Text.assemble(FILE_ICON, "")) console.print(tree) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 41e22849..0fbc23e5 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -610,3 +610,29 @@ def task_example( result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert "node-name" in result.output + + +@pytest.mark.end_to_end() +def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): + source = """ + from pathlib import Path + from typing import Any + from typing_extensions import Annotated + from pytask import PythonNode + + nodes = [ + PythonNode(name="dict"), + (PythonNode(name="tuple1"), PythonNode(name="tuple2")), + PythonNode(name="int") + ] + + def task_example() -> Annotated[dict[str, str], nodes]: + return [{"first": "a", "second": "b"}, (1, 2), 1] + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "dict" in result.output + assert "tuple1" in result.output + assert "tuple2" in result.output + assert "int" in result.output diff --git a/tests/test_execute.py b/tests/test_execute.py index 83117ecb..a395afe8 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -669,7 +669,7 @@ def task_example() -> Annotated[str, (node1, node2)]: @pytest.mark.end_to_end() -def test_pytree_and_python_node(runner, tmp_path): +def test_pytree_and_python_node_as_return(runner, tmp_path): source = """ from pathlib import Path from typing import Any @@ -682,3 +682,25 @@ def task_example() -> Annotated[dict[str, str], PythonNode(name="result")]: tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK + + +@pytest.mark.end_to_end() +def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): + source = """ + from pathlib import Path + from typing import Any + from typing_extensions import Annotated + from pytask import PythonNode + + nodes = [ + PythonNode(name="dict"), + (PythonNode(name="tuple1"), PythonNode(name="tuple2")), + PythonNode(name="int") + ] + + def task_example() -> Annotated[dict[str, str], nodes]: + return [{"first": "a", "second": "b"}, (1, 2), 1] + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK From c315f48568e8d0829fa7fb3313afea6c5689bb0d Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Mon, 28 Aug 2023 09:32:22 +0200 Subject: [PATCH 3/5] more tests. --- src/_pytask/nodes.py | 1 + tests/test_tree_util.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index cb5e18e6..111759f5 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -81,6 +81,7 @@ def execute(self, **kwargs: Any) -> None: if "return" in self.produces: structure_out = tree_structure(out) structure_return = tree_structure(self.produces["return"]) + # strict must be false. if not structure_return.is_prefix(structure_out, strict=False): raise ValueError( "The structure of the return annotation is not a subtree of the " diff --git a/tests/test_tree_util.py b/tests/test_tree_util.py index 34341e9b..2f39b1a7 100644 --- a/tests/test_tree_util.py +++ b/tests/test_tree_util.py @@ -6,6 +6,7 @@ import pytest from _pytask.outcomes import ExitCode from _pytask.tree_util import tree_map +from _pytask.tree_util import tree_structure from pytask import cli from pytask import main @@ -80,3 +81,20 @@ def task_example(produces): assert "0." in result.output assert "Size of Products" in result.output assert "86 bytes" in result.output + + +@pytest.mark.unit() +@pytest.mark.parametrize( + ("prefix_tree", "full_tree", "strict", "expected"), + [ + # This is why strict cannot be true when parsing function returns. + (1, 1, True, False), + (1, 1, False, True), + ({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, False, True), + ({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, True, True), + ], +) +def test_is_prefix(prefix_tree, full_tree, strict, expected): + prefix_structure = tree_structure(prefix_tree) + full_tree_structure = tree_structure(full_tree) + assert prefix_structure.is_prefix(full_tree_structure, strict=strict) is expected From 91290f9598a434c0bd67be34d4786187ad82724f Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Mon, 28 Aug 2023 09:33:54 +0200 Subject: [PATCH 4/5] Fix. --- src/_pytask/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 111759f5..8bc550c8 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -81,7 +81,7 @@ def execute(self, **kwargs: Any) -> None: if "return" in self.produces: structure_out = tree_structure(out) structure_return = tree_structure(self.produces["return"]) - # strict must be false. + # strict must be false when none is leaf. if not structure_return.is_prefix(structure_out, strict=False): raise ValueError( "The structure of the return annotation is not a subtree of the " From 1a2145b16ebcacfe24142306618c284ae277b40a Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Mon, 28 Aug 2023 12:08:12 +0200 Subject: [PATCH 5/5] Fix. --- tests/test_collect_command.py | 3 ++- tests/test_execute.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 0fbc23e5..3007221d 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -619,6 +619,7 @@ def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): from typing import Any from typing_extensions import Annotated from pytask import PythonNode + from typing import Dict nodes = [ PythonNode(name="dict"), @@ -626,7 +627,7 @@ def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): PythonNode(name="int") ] - def task_example() -> Annotated[dict[str, str], nodes]: + def task_example() -> Annotated[Dict[str, str], nodes]: return [{"first": "a", "second": "b"}, (1, 2), 1] """ tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) diff --git a/tests/test_execute.py b/tests/test_execute.py index a395afe8..442c471a 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -675,8 +675,9 @@ def test_pytree_and_python_node_as_return(runner, tmp_path): from typing import Any from typing_extensions import Annotated from pytask import PythonNode + from typing import Dict - def task_example() -> Annotated[dict[str, str], PythonNode(name="result")]: + def task_example() -> Annotated[Dict[str, str], PythonNode(name="result")]: return {"first": "a", "second": "b"} """ tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) @@ -691,6 +692,7 @@ def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): from typing import Any from typing_extensions import Annotated from pytask import PythonNode + from typing import Dict nodes = [ PythonNode(name="dict"), @@ -698,7 +700,7 @@ def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path): PythonNode(name="int") ] - def task_example() -> Annotated[dict[str, str], nodes]: + def task_example() -> Annotated[Dict[str, str], nodes]: return [{"first": "a", "second": "b"}, (1, 2), 1] """ tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))