Skip to content

Commit 75e3659

Browse files
committed
More tests.
1 parent 01f0fc8 commit 75e3659

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

src/_pytask/nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ProductType:
3131

3232

3333
Product = ProductType()
34+
"""ProductType: A singleton to mark products in annotations."""
3435

3536

3637
@define(kw_only=True)
@@ -84,7 +85,8 @@ def execute(self, **kwargs: Any) -> None:
8485
if not structure_out == structure_return:
8586
raise ValueError(
8687
"The structure of the function return does not match the structure "
87-
"of the return annotation."
88+
f"of the return annotation.\n\nFunction return: {structure_out}\n\n"
89+
f"Return annotation: {structure_return}"
8890
)
8991

9092
for out_, return_ in zip(
@@ -142,6 +144,7 @@ def state(self) -> str | None:
142144
return None
143145

144146
def load(self) -> Path:
147+
"""Load the value."""
145148
return self.value
146149

147150
def save(self, value: bytes | str) -> None:
@@ -168,12 +171,15 @@ class PythonNode(Node):
168171
"""Whether the value should be hashed to determine the state."""
169172

170173
def load(self) -> Any:
174+
"""Load the value."""
171175
return self.value
172176

173177
def save(self, value: Any) -> NoReturn:
178+
"""Save the value."""
174179
raise NotImplementedError
175180

176181
def from_annot(self, value: Any) -> None:
182+
"""Set the value from a function annotation."""
177183
self.value = value
178184

179185
def state(self) -> str | None:

tests/test_execute.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,7 @@ def test_return_with_path_annotation_as_return(runner, tmp_path):
555555
from typing_extensions import Annotated
556556
from pytask import PathNode
557557
558-
path = Path(__file__).parent.joinpath("file.txt")
559-
560-
def task_example() -> Annotated[str, path]:
558+
def task_example() -> Annotated[str, Path("file.txt")]:
561559
return "Hello, World!"
562560
"""
563561
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
@@ -647,3 +645,24 @@ def task_example() -> Annotated[int, node]:
647645

648646
data = pickle.loads(tmp_path.joinpath("data.pkl").read_bytes()) # noqa: S301
649647
assert data == 1
648+
649+
650+
@pytest.mark.end_to_end()
651+
def test_error_when_return_pytree_mismatch(runner, tmp_path):
652+
source = """
653+
from pathlib import Path
654+
from typing import Any
655+
from typing_extensions import Annotated
656+
from pytask import PathNode
657+
658+
node1 = PathNode.from_path(Path(__file__).parent.joinpath("file1.txt"))
659+
node2 = PathNode.from_path(Path(__file__).parent.joinpath("file2.txt"))
660+
661+
def task_example() -> Annotated[str, (node1, node2)]:
662+
return "Hello,"
663+
"""
664+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
665+
result = runner.invoke(cli, [tmp_path.as_posix()])
666+
assert result.exit_code == ExitCode.FAILED
667+
assert "Function return: PyTreeSpec(*, NoneIsLeaf)" in result.output
668+
assert "Return annotation: PyTreeSpec((*, *), NoneIsLeaf)" in result.output

0 commit comments

Comments
 (0)