Skip to content

Commit 5e5d9bc

Browse files
authored
Merge 1a2145b into dbe1fce
2 parents dbe1fce + 1a2145b commit 5e5d9bc

9 files changed

+129
-27
lines changed

docs/source/changes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
2323
`@pytask.mark.produces`.
2424
- {pull}`402` replaces ABCs with protocols allowing for more flexibility for users
2525
implementing their own nodes.
26+
- {pull}`404` allows to use function returns to define task products.
27+
- {pull}`405` allows to match function returns to node annotations with prefix trees.
2628

2729
## 0.3.2 - 2023-06-07
2830

src/_pytask/collect_command.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def _find_common_ancestor_of_all_nodes(
127127
all_paths.extend(
128128
x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode)
129129
)
130-
all_paths.extend(x.path for x in tree_leaves(task.produces))
130+
all_paths.extend(
131+
x.path for x in tree_leaves(task.produces) if isinstance(x, PPathNode)
132+
)
131133

132134
common_ancestor = find_common_ancestor(*all_paths, *paths)
133135

@@ -219,10 +221,17 @@ def _print_collected_tasks(
219221

220222
task_branch.add(Text.assemble(FILE_ICON, "<Dependency ", text, ">"))
221223

222-
for node in sorted(tree_leaves(task.produces), key=lambda x: x.path):
223-
reduced_node_name = str(relative_to(node.path, common_ancestor))
224-
url_style = create_url_style_for_path(node.path, editor_url_scheme)
225-
text = Text(reduced_node_name, style=url_style)
224+
for node in sorted(
225+
tree_leaves(task.produces), key=lambda x: getattr(x, "path", x.name)
226+
):
227+
if isinstance(node, PPathNode):
228+
reduced_node_name = str(relative_to(node.path, common_ancestor))
229+
url_style = create_url_style_for_path(
230+
node.path, editor_url_scheme
231+
)
232+
text = Text(reduced_node_name, style=url_style)
233+
else:
234+
text = Text(node.name)
226235
task_branch.add(Text.assemble(FILE_ICON, "<Product ", text, ">"))
227236

228237
console.print(tree)

src/_pytask/collect_utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from _pytask.mark_utils import remove_marks
1919
from _pytask.models import NodeInfo
2020
from _pytask.node_protocols import Node
21-
from _pytask.node_protocols import PPathNode
2221
from _pytask.nodes import ProductType
2322
from _pytask.nodes import PythonNode
2423
from _pytask.shared import find_duplicates
@@ -580,15 +579,9 @@ def _collect_product(
580579
"type 'str' and 'pathlib.Path' or the same values nested in "
581580
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
582581
)
583-
# The parameter defaults only support Path objects.
584-
if not isinstance(node, (Path, PPathNode)) and not is_string_allowed:
585-
raise ValueError(
586-
"If you declare products with 'Annotated[..., Product]', only values of "
587-
"type 'pathlib.Path' optionally nested in tuples, lists, and "
588-
f"dictionaries are allowed. Here, {node!r} has type {type(node)}."
589-
)
590582

591-
if isinstance(node, str):
583+
# If we encounter a string and it is allowed, convert it to a path.
584+
if isinstance(node, str) and is_string_allowed:
592585
node = Path(node)
593586
node_info = node_info._replace(value=node)
594587

src/_pytask/nodes.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77
from typing import Any
88
from typing import Callable
9-
from typing import NoReturn
109
from typing import TYPE_CHECKING
1110

1211
from _pytask.node_protocols import MetaNode
@@ -82,17 +81,18 @@ def execute(self, **kwargs: Any) -> None:
8281
if "return" in self.produces:
8382
structure_out = tree_structure(out)
8483
structure_return = tree_structure(self.produces["return"])
85-
if not structure_out == structure_return:
84+
# strict must be false when none is leaf.
85+
if not structure_return.is_prefix(structure_out, strict=False):
8686
raise ValueError(
87-
"The structure of the function return does not match the structure "
88-
f"of the return annotation.\n\nFunction return: {structure_out}\n\n"
89-
f"Return annotation: {structure_return}"
87+
"The structure of the return annotation is not a subtree of the "
88+
"structure of the function return.\n\nFunction return: "
89+
f"{structure_out}\n\nReturn annotation: {structure_return}"
9090
)
9191

92-
for out_, return_ in zip(
93-
tree_leaves(out), tree_leaves(self.produces["return"])
94-
):
95-
return_.save(out_)
92+
nodes = tree_leaves(self.produces["return"])
93+
values = structure_return.flatten_up_to(out)
94+
for node, value in zip(nodes, values):
95+
node.save(value)
9696

9797
def add_report_section(self, when: str, key: str, content: str) -> None:
9898
"""Add sections which will be displayed in report like stdout or stderr."""
@@ -174,9 +174,9 @@ def load(self) -> Any:
174174
"""Load the value."""
175175
return self.value
176176

177-
def save(self, value: Any) -> NoReturn:
177+
def save(self, value: Any) -> None:
178178
"""Save the value."""
179-
raise NotImplementedError
179+
self.value = value
180180

181181
def from_annot(self, value: Any) -> None:
182182
"""Set the value from a function annotation."""

src/_pytask/tree_util.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33

44
import functools
55
from pathlib import Path
6+
from typing import Any
67

78
import optree
89
from optree import PyTree
10+
from optree import tree_flatten_with_path as _optree_tree_flatten_with_path
911
from optree import tree_leaves as _optree_tree_leaves
1012
from optree import tree_map as _optree_tree_map
1113
from optree import tree_map_with_path as _optree_tree_map_with_path
1214
from optree import tree_structure as _optree_tree_structure
1315

1416

1517
__all__ = [
18+
"tree_flatten_with_path",
1619
"tree_leaves",
1720
"tree_map",
1821
"tree_map_with_path",
@@ -34,3 +37,13 @@
3437
tree_structure = functools.partial(
3538
_optree_tree_structure, none_is_leaf=True, namespace="pytask"
3639
)
40+
tree_flatten_with_path = functools.partial(
41+
_optree_tree_flatten_with_path, none_is_leaf=True, namespace="pytask"
42+
)
43+
44+
45+
def tree_index(path: tuple[Any, ...], tree: PyTree) -> Any:
46+
"""Index a tree with a path."""
47+
if not path:
48+
return tree
49+
return tree_index(path[1:], tree[path[0]])

tests/test_collect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def task_write_text(produces="out.txt"):
340340

341341
@pytest.mark.end_to_end()
342342
def test_collect_string_product_raises_error_with_annotation(runner, tmp_path):
343+
"""The string is not converted to a path."""
343344
source = """
344345
from pytask import Product
345346
from typing_extensions import Annotated
@@ -349,8 +350,7 @@ def task_write_text(out: Annotated[str, Product] = "out.txt") -> None:
349350
"""
350351
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
351352
result = runner.invoke(cli, [tmp_path.as_posix()])
352-
assert result.exit_code == ExitCode.COLLECTION_FAILED
353-
assert "If you declare products with 'Annotated[..., Product]'" in result.output
353+
assert result.exit_code == ExitCode.FAILED
354354

355355

356356
@pytest.mark.end_to_end()

tests/test_collect_command.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,30 @@ def task_example(
610610
result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])
611611
assert result.exit_code == ExitCode.OK
612612
assert "node-name" in result.output
613+
614+
615+
@pytest.mark.end_to_end()
616+
def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path):
617+
source = """
618+
from pathlib import Path
619+
from typing import Any
620+
from typing_extensions import Annotated
621+
from pytask import PythonNode
622+
from typing import Dict
623+
624+
nodes = [
625+
PythonNode(name="dict"),
626+
(PythonNode(name="tuple1"), PythonNode(name="tuple2")),
627+
PythonNode(name="int")
628+
]
629+
630+
def task_example() -> Annotated[Dict[str, str], nodes]:
631+
return [{"first": "a", "second": "b"}, (1, 2), 1]
632+
"""
633+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
634+
result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])
635+
assert result.exit_code == ExitCode.OK
636+
assert "dict" in result.output
637+
assert "tuple1" in result.output
638+
assert "tuple2" in result.output
639+
assert "int" in result.output

tests/test_execute.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,43 @@ def task_example() -> Annotated[str, (node1, node2)]:
666666
assert result.exit_code == ExitCode.FAILED
667667
assert "Function return: PyTreeSpec(*, NoneIsLeaf)" in result.output
668668
assert "Return annotation: PyTreeSpec((*, *), NoneIsLeaf)" in result.output
669+
670+
671+
@pytest.mark.end_to_end()
672+
def test_pytree_and_python_node_as_return(runner, tmp_path):
673+
source = """
674+
from pathlib import Path
675+
from typing import Any
676+
from typing_extensions import Annotated
677+
from pytask import PythonNode
678+
from typing import Dict
679+
680+
def task_example() -> Annotated[Dict[str, str], PythonNode(name="result")]:
681+
return {"first": "a", "second": "b"}
682+
"""
683+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
684+
result = runner.invoke(cli, [tmp_path.as_posix()])
685+
assert result.exit_code == ExitCode.OK
686+
687+
688+
@pytest.mark.end_to_end()
689+
def test_more_nested_pytree_and_python_node_as_return(runner, tmp_path):
690+
source = """
691+
from pathlib import Path
692+
from typing import Any
693+
from typing_extensions import Annotated
694+
from pytask import PythonNode
695+
from typing import Dict
696+
697+
nodes = [
698+
PythonNode(name="dict"),
699+
(PythonNode(name="tuple1"), PythonNode(name="tuple2")),
700+
PythonNode(name="int")
701+
]
702+
703+
def task_example() -> Annotated[Dict[str, str], nodes]:
704+
return [{"first": "a", "second": "b"}, (1, 2), 1]
705+
"""
706+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
707+
result = runner.invoke(cli, [tmp_path.as_posix()])
708+
assert result.exit_code == ExitCode.OK

tests/test_tree_util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from _pytask.outcomes import ExitCode
88
from _pytask.tree_util import tree_map
9+
from _pytask.tree_util import tree_structure
910
from pytask import cli
1011
from pytask import main
1112

@@ -80,3 +81,20 @@ def task_example(produces):
8081
assert "0." in result.output
8182
assert "Size of Products" in result.output
8283
assert "86 bytes" in result.output
84+
85+
86+
@pytest.mark.unit()
87+
@pytest.mark.parametrize(
88+
("prefix_tree", "full_tree", "strict", "expected"),
89+
[
90+
# This is why strict cannot be true when parsing function returns.
91+
(1, 1, True, False),
92+
(1, 1, False, True),
93+
({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, False, True),
94+
({"a": 1, "b": 1}, {"a": 1, "b": {"c": 1, "d": 1}}, True, True),
95+
],
96+
)
97+
def test_is_prefix(prefix_tree, full_tree, strict, expected):
98+
prefix_structure = tree_structure(prefix_tree)
99+
full_tree_structure = tree_structure(full_tree)
100+
assert prefix_structure.is_prefix(full_tree_structure, strict=strict) is expected

0 commit comments

Comments
 (0)