Skip to content

Commit 01f0fc8

Browse files
committed
Allow to inject task returns via @pytask.mark.task(produces=...).
1 parent d3bd55a commit 01f0fc8

File tree

5 files changed

+102
-25
lines changed

5 files changed

+102
-25
lines changed

src/_pytask/collect_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def parse_products_from_task_function(
359359
has_produces_argument = False
360360
has_annotation = False
361361
has_return = False
362+
has_task_decorator = False
362363
out = {}
363364

364365
# Parse products from decorators.
@@ -428,8 +429,31 @@ def parse_products_from_task_function(
428429
)
429430
out = {"return": collected_products}
430431

432+
task_produces = obj.pytask_meta.produces if hasattr(obj, "pytask_meta") else None
433+
if task_produces:
434+
has_task_decorator = True
435+
collected_products = tree_map_with_path(
436+
lambda p, x: _collect_product(
437+
session,
438+
path,
439+
name,
440+
NodeInfo("return", p, x),
441+
is_string_allowed=False,
442+
),
443+
task_produces,
444+
)
445+
out = {"return": collected_products}
446+
431447
if (
432-
sum((has_produces_decorator, has_produces_argument, has_annotation, has_return))
448+
sum(
449+
(
450+
has_produces_decorator,
451+
has_produces_argument,
452+
has_annotation,
453+
has_return,
454+
has_task_decorator,
455+
)
456+
)
433457
>= 2 # noqa: PLR2004
434458
):
435459
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

src/_pytask/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import NamedTuple
66
from typing import TYPE_CHECKING
77

8+
from _pytask.tree_util import PyTree
89
from attrs import define
910
from attrs import field
1011

@@ -24,6 +25,8 @@ class CollectionMetadata:
2425
"""Contains the markers of the function."""
2526
name: str | None = None
2627
"""The name of the task function."""
28+
produces: PyTree[Any] = None
29+
"""Definition of products to handle returns."""
2730

2831

2932
class NodeInfo(NamedTuple):

src/_pytask/task_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from _pytask.mark import Mark
1212
from _pytask.models import CollectionMetadata
1313
from _pytask.shared import find_duplicates
14+
from _pytask.tree_util import PyTree
1415

1516

1617
__all__ = ["parse_keyword_arguments_from_signature_defaults"]
@@ -31,6 +32,7 @@ def task(
3132
*,
3233
id: str | None = None, # noqa: A002
3334
kwargs: dict[Any, Any] | None = None,
35+
produces: PyTree[Any] = None,
3436
) -> Callable[..., None]:
3537
"""Parse inputs of the ``@pytask.mark.task`` decorator.
3638
@@ -43,13 +45,15 @@ def task(
4345
4446
Parameters
4547
----------
46-
name : str | None
48+
name
4749
The name of the task.
48-
id : str | None
50+
id
4951
An id for the task if it is part of a parametrization.
50-
kwargs : dict[Any, Any] | None
52+
kwargs
5153
A dictionary containing keyword arguments which are passed to the task when it
5254
is executed.
55+
produces
56+
Definition of products to handle returns.
5357
5458
"""
5559

@@ -77,12 +81,14 @@ def wrapper(func: Callable[..., Any]) -> None:
7781
unwrapped.pytask_meta.kwargs = parsed_kwargs
7882
unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
7983
unwrapped.pytask_meta.id_ = id
84+
unwrapped.pytask_meta.produces = produces
8085
else:
8186
unwrapped.pytask_meta = CollectionMetadata(
8287
name=parsed_name,
8388
kwargs=parsed_kwargs,
8489
markers=[Mark("task", (), {})],
8590
id_=id,
91+
produces=produces,
8692
)
8793

8894
COLLECTED_TASKS[path].append(unwrapped)

tests/test_execute.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,27 @@ def task_example() -> Annotated[str, node]:
585585
assert tmp_path.joinpath("file.txt").read_text() == "Hello, World!"
586586

587587

588+
@pytest.mark.end_to_end()
589+
def test_return_with_tuple_pathnode_annotation_as_return(runner, tmp_path):
590+
source = """
591+
from pathlib import Path
592+
from typing import Any
593+
from typing_extensions import Annotated
594+
from pytask import PathNode
595+
596+
node1 = PathNode.from_path(Path(__file__).parent.joinpath("file1.txt"))
597+
node2 = PathNode.from_path(Path(__file__).parent.joinpath("file2.txt"))
598+
599+
def task_example() -> Annotated[str, (node1, node2)]:
600+
return "Hello,", "World!"
601+
"""
602+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
603+
result = runner.invoke(cli, [tmp_path.as_posix()])
604+
assert result.exit_code == ExitCode.OK
605+
assert tmp_path.joinpath("file1.txt").read_text() == "Hello,"
606+
assert tmp_path.joinpath("file2.txt").read_text() == "World!"
607+
608+
588609
@pytest.mark.end_to_end()
589610
def test_return_with_custom_type_annotation_as_return(runner, tmp_path):
590611
source = """
@@ -626,24 +647,3 @@ def task_example() -> Annotated[int, node]:
626647

627648
data = pickle.loads(tmp_path.joinpath("data.pkl").read_bytes()) # noqa: S301
628649
assert data == 1
629-
630-
631-
@pytest.mark.end_to_end()
632-
def test_return_with_tuple_pathnode_annotation_as_return(runner, tmp_path):
633-
source = """
634-
from pathlib import Path
635-
from typing import Any
636-
from typing_extensions import Annotated
637-
from pytask import PathNode
638-
639-
node1 = PathNode.from_path(Path(__file__).parent.joinpath("file1.txt"))
640-
node2 = PathNode.from_path(Path(__file__).parent.joinpath("file2.txt"))
641-
642-
def task_example() -> Annotated[str, (node1, node2)]:
643-
return "Hello,", "World!"
644-
"""
645-
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
646-
result = runner.invoke(cli, [tmp_path.as_posix()])
647-
assert result.exit_code == ExitCode.OK
648-
assert tmp_path.joinpath("file1.txt").read_text() == "Hello,"
649-
assert tmp_path.joinpath("file2.txt").read_text() == "World!"

tests/test_task.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,47 @@ def task_example(
511511
assert result.exit_code == ExitCode.OK
512512
assert tmp_path.joinpath("out.txt").read_text() == "Hello world!"
513513
assert not tmp_path.joinpath("not_used_out.txt").exists()
514+
515+
516+
@pytest.mark.end_to_end()
517+
def test_return_with_task_decorator(runner, tmp_path):
518+
source = """
519+
from pathlib import Path
520+
from typing import Any
521+
from typing_extensions import Annotated
522+
from pytask import PathNode
523+
import pytask
524+
525+
node = PathNode.from_path(Path(__file__).parent.joinpath("file.txt"))
526+
527+
@pytask.mark.task(produces=node)
528+
def task_example():
529+
return "Hello, World!"
530+
"""
531+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
532+
result = runner.invoke(cli, [tmp_path.as_posix()])
533+
assert result.exit_code == ExitCode.OK
534+
assert tmp_path.joinpath("file.txt").read_text() == "Hello, World!"
535+
536+
537+
@pytest.mark.end_to_end()
538+
def test_return_with_tuple_and_task_decorator(runner, tmp_path):
539+
source = """
540+
from pathlib import Path
541+
from typing import Any
542+
from typing_extensions import Annotated
543+
from pytask import PathNode
544+
import pytask
545+
546+
node1 = PathNode.from_path(Path(__file__).parent.joinpath("file1.txt"))
547+
node2 = PathNode.from_path(Path(__file__).parent.joinpath("file2.txt"))
548+
549+
@pytask.mark.task(produces=(node1, node2))
550+
def task_example():
551+
return "Hello,", "World!"
552+
"""
553+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
554+
result = runner.invoke(cli, [tmp_path.as_posix()])
555+
assert result.exit_code == ExitCode.OK
556+
assert tmp_path.joinpath("file1.txt").read_text() == "Hello,"
557+
assert tmp_path.joinpath("file2.txt").read_text() == "World!"

0 commit comments

Comments
 (0)