Skip to content

Commit 70e5de9

Browse files
authored
Merge 75e3659 into 37d244e
2 parents 37d244e + 75e3659 commit 70e5de9

13 files changed

+350
-34
lines changed

src/_pytask/collect_utils.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,7 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, Node]
308308

309309
args_with_node_annotation = {}
310310
for name, meta in metas.items():
311-
annot = [
312-
i for i in meta if not isinstance(i, ProductType) and isinstance(i, Node)
313-
]
311+
annot = [i for i in meta if not isinstance(i, ProductType)]
314312
if len(annot) >= 2: # noqa: PLR2004
315313
raise ValueError(
316314
f"Parameter {name!r} has multiple node annotations although only one "
@@ -360,6 +358,8 @@ def parse_products_from_task_function(
360358
has_produces_decorator = False
361359
has_produces_argument = False
362360
has_annotation = False
361+
has_return = False
362+
has_task_decorator = False
363363
out = {}
364364

365365
# Parse products from decorators.
@@ -415,8 +415,45 @@ def parse_products_from_task_function(
415415
)
416416
out = {parameter_name: collected_products}
417417

418+
if "return" in parameters_with_node_annot:
419+
has_return = True
420+
collected_products = tree_map_with_path(
421+
lambda p, x: _collect_product(
422+
session,
423+
path,
424+
name,
425+
NodeInfo("return", p, x),
426+
is_string_allowed=False,
427+
),
428+
parameters_with_node_annot["return"],
429+
)
430+
out = {"return": collected_products}
431+
432+
task_produces = obj.pytask_meta.produces if hasattr(obj, "pytask_meta") else None
433+
if task_produces:
434+
has_task_decorator = True
435+
collected_products = tree_map_with_path(
436+
lambda p, x: _collect_product(
437+
session,
438+
path,
439+
name,
440+
NodeInfo("return", p, x),
441+
is_string_allowed=False,
442+
),
443+
task_produces,
444+
)
445+
out = {"return": collected_products}
446+
418447
if (
419-
sum((has_produces_decorator, has_produces_argument, has_annotation))
448+
sum(
449+
(
450+
has_produces_decorator,
451+
has_produces_argument,
452+
has_annotation,
453+
has_return,
454+
has_task_decorator,
455+
)
456+
)
420457
>= 2 # noqa: PLR2004
421458
):
422459
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)
@@ -571,5 +608,5 @@ def _evolve_instance(x: Any, instance_from_annot: Node | None) -> Any:
571608
if not instance_from_annot:
572609
return x
573610

574-
instance_from_annot.value = x
611+
instance_from_annot.from_annot(x)
575612
return instance_from_annot

src/_pytask/execute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def pytask_execute_task(session: Session, task: Task) -> bool:
147147

148148
kwargs = {}
149149
for name, value in task.depends_on.items():
150-
kwargs[name] = tree_map(lambda x: x.value, value)
150+
kwargs[name] = tree_map(lambda x: x.load(), value)
151151

152152
for name, value in task.produces.items():
153153
if name in parameters:
154-
kwargs[name] = tree_map(lambda x: x.value, value)
154+
kwargs[name] = tree_map(lambda x: x.load(), value)
155155

156156
task.execute(**kwargs)
157157
return True

src/_pytask/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import NamedTuple
66
from typing import TYPE_CHECKING
77

8+
from _pytask.tree_util import PyTree
89
from attrs import define
910
from attrs import field
1011

@@ -24,6 +25,8 @@ class CollectionMetadata:
2425
"""Contains the markers of the function."""
2526
name: str | None = None
2627
"""The name of the task function."""
28+
produces: PyTree[Any] = None
29+
"""Definition of products to handle returns."""
2730

2831

2932
class NodeInfo(NamedTuple):

src/_pytask/node_protocols.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@ class MetaNode(Protocol):
1515
"""The name of node that must be unique."""
1616

1717
@abstractmethod
18-
def state(self) -> Any:
18+
def state(self) -> str | None:
19+
"""Return the state of the node.
20+
21+
The state can be something like a hash or a last modified timestamp. If the node
22+
does not exist, you can also return ``None``.
23+
24+
"""
1925
...
2026

2127

@@ -25,9 +31,36 @@ class Node(MetaNode, Protocol):
2531

2632
value: Any
2733

34+
def load(self) -> Any:
35+
"""Return the value of the node that will be injected into the task."""
36+
...
37+
38+
def save(self, value: Any) -> Any:
39+
"""Save the value that was returned from a task."""
40+
...
41+
42+
def from_annot(self, value: Any) -> Any:
43+
"""Complete the node by setting the value from an default argument.
44+
45+
Use it, if you want to add information on how a node handles an argument while
46+
keeping the type of the value unrelated to pytask.
47+
48+
.. codeblock: python
49+
50+
def task_example(value: Annotated[Any, PythonNode(hash=True)], produces):
51+
...
52+
53+
54+
"""
55+
...
56+
2857

2958
@runtime_checkable
3059
class PPathNode(Node, Protocol):
31-
"""Nodes with paths."""
60+
"""Nodes with paths.
61+
62+
Nodes with paths receive special handling when it comes to printing their names.
63+
64+
"""
3265

3366
path: Path

src/_pytask/nodes.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from pathlib import Path
77
from typing import Any
88
from typing import Callable
9+
from typing import NoReturn
910
from typing import TYPE_CHECKING
1011

1112
from _pytask.node_protocols import MetaNode
1213
from _pytask.node_protocols import Node
1314
from _pytask.tree_util import PyTree
15+
from _pytask.tree_util import tree_leaves
16+
from _pytask.tree_util import tree_structure
1417
from attrs import define
1518
from attrs import field
1619

@@ -28,6 +31,7 @@ class ProductType:
2831

2932

3033
Product = ProductType()
34+
"""ProductType: A singleton to mark products in annotations."""
3135

3236

3337
@define(kw_only=True)
@@ -64,6 +68,7 @@ def __attrs_post_init__(self: Task) -> None:
6468
self.short_name = self.name
6569

6670
def state(self, hash: bool = False) -> str | None: # noqa: A002
71+
"""Return the state of the node."""
6772
if hash and self.path.exists():
6873
return hashlib.sha256(self.path.read_bytes()).hexdigest()
6974
if not hash and self.path.exists():
@@ -72,7 +77,22 @@ def state(self, hash: bool = False) -> str | None: # noqa: A002
7277

7378
def execute(self, **kwargs: Any) -> None:
7479
"""Execute the task."""
75-
self.function(**kwargs)
80+
out = self.function(**kwargs)
81+
82+
if "return" in self.produces:
83+
structure_out = tree_structure(out)
84+
structure_return = tree_structure(self.produces["return"])
85+
if not structure_out == structure_return:
86+
raise ValueError(
87+
"The structure of the function return does not match the structure "
88+
f"of the return annotation.\n\nFunction return: {structure_out}\n\n"
89+
f"Return annotation: {structure_return}"
90+
)
91+
92+
for out_, return_ in zip(
93+
tree_leaves(out), tree_leaves(self.produces["return"])
94+
):
95+
return_.save(out_)
7696

7797
def add_report_section(self, when: str, key: str, content: str) -> None:
7898
"""Add sections which will be displayed in report like stdout or stderr."""
@@ -86,24 +106,20 @@ class PathNode(Node):
86106

87107
name: str = ""
88108
"""Name of the node which makes it identifiable in the DAG."""
89-
_value: Path | None = None
109+
value: Path | None = None
90110
"""Value passed to the decorator which can be requested inside the function."""
91111

92112
@property
93113
def path(self) -> Path:
94114
return self.value
95115

96-
@property
97-
def value(self) -> Path:
98-
return self._value
99-
100-
@value.setter
101-
def value(self, value: Path) -> None:
116+
def from_annot(self, value: Path) -> None:
117+
"""Set path and if other attributes are not set, set sensible defaults."""
102118
if not isinstance(value, Path):
103119
raise TypeError("'value' must be a 'pathlib.Path'.")
104120
if not self.name:
105121
self.name = value.as_posix()
106-
self._value = value
122+
self.value = value
107123

108124
@classmethod
109125
@functools.lru_cache
@@ -127,18 +143,45 @@ def state(self) -> str | None:
127143
return str(self.path.stat().st_mtime)
128144
return None
129145

146+
def load(self) -> Path:
147+
"""Load the value."""
148+
return self.value
149+
150+
def save(self, value: bytes | str) -> None:
151+
"""Save strings or bytes to file."""
152+
if isinstance(value, str):
153+
self.path.write_text(value)
154+
elif isinstance(value, bytes):
155+
self.path.write_bytes(value)
156+
else:
157+
raise TypeError(
158+
f"'PathNode' can only save 'str' and 'bytes', not {type(value)}"
159+
)
160+
130161

131162
@define(kw_only=True)
132163
class PythonNode(Node):
133164
"""The class for a node which is a Python object."""
134165

135166
name: str = ""
136167
"""Name of the node."""
137-
value: Any | None = None
168+
value: Any = None
138169
"""Value of the node."""
139170
hash: bool = False # noqa: A003
140171
"""Whether the value should be hashed to determine the state."""
141172

173+
def load(self) -> Any:
174+
"""Load the value."""
175+
return self.value
176+
177+
def save(self, value: Any) -> NoReturn:
178+
"""Save the value."""
179+
raise NotImplementedError
180+
181+
def from_annot(self, value: Any) -> None:
182+
"""Set the value from a function annotation."""
183+
self.value = value
184+
142185
def state(self) -> str | None:
143186
"""Calculate state of the node.
144187

src/_pytask/task_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from _pytask.mark import Mark
1212
from _pytask.models import CollectionMetadata
1313
from _pytask.shared import find_duplicates
14+
from _pytask.tree_util import PyTree
1415

1516

1617
__all__ = ["parse_keyword_arguments_from_signature_defaults"]
@@ -31,6 +32,7 @@ def task(
3132
*,
3233
id: str | None = None, # noqa: A002
3334
kwargs: dict[Any, Any] | None = None,
35+
produces: PyTree[Any] = None,
3436
) -> Callable[..., None]:
3537
"""Parse inputs of the ``@pytask.mark.task`` decorator.
3638
@@ -43,13 +45,15 @@ def task(
4345
4446
Parameters
4547
----------
46-
name : str | None
48+
name
4749
The name of the task.
48-
id : str | None
50+
id
4951
An id for the task if it is part of a parametrization.
50-
kwargs : dict[Any, Any] | None
52+
kwargs
5153
A dictionary containing keyword arguments which are passed to the task when it
5254
is executed.
55+
produces
56+
Definition of products to handle returns.
5357
5458
"""
5559

@@ -77,12 +81,14 @@ def wrapper(func: Callable[..., Any]) -> None:
7781
unwrapped.pytask_meta.kwargs = parsed_kwargs
7882
unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
7983
unwrapped.pytask_meta.id_ = id
84+
unwrapped.pytask_meta.produces = produces
8085
else:
8186
unwrapped.pytask_meta = CollectionMetadata(
8287
name=parsed_name,
8388
kwargs=parsed_kwargs,
8489
markers=[Mark("task", (), {})],
8590
id_=id,
91+
produces=produces,
8692
)
8793

8894
COLLECTED_TASKS[path].append(unwrapped)

src/_pytask/tree_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from optree import tree_leaves as _optree_tree_leaves
1010
from optree import tree_map as _optree_tree_map
1111
from optree import tree_map_with_path as _optree_tree_map_with_path
12+
from optree import tree_structure as _optree_tree_structure
1213

1314

1415
__all__ = [
1516
"tree_leaves",
1617
"tree_map",
1718
"tree_map_with_path",
19+
"tree_structure",
1820
"PyTree",
1921
"TREE_UTIL_LIB_DIRECTORY",
2022
]
@@ -29,3 +31,6 @@
2931
tree_map_with_path = functools.partial(
3032
_optree_tree_map_with_path, none_is_leaf=True, namespace="pytask"
3133
)
34+
tree_structure = functools.partial(
35+
_optree_tree_structure, none_is_leaf=True, namespace="pytask"
36+
)

tests/test_collect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_pytask_collect_node(session, path, node_info, expected):
172172
if result is None:
173173
assert result is expected
174174
else:
175-
assert str(result.value) == str(expected)
175+
assert str(result.load()) == str(expected)
176176

177177

178178
@pytest.mark.unit()

0 commit comments

Comments
 (0)