Skip to content

Add support for NamedTuple and attrs classes in @pytask.mark.task(kwargs=...). #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`395` refactors all occurrences of pybaum to {mod}`_pytask.tree_util`.
- {pull}`396` replaces pybaum with optree and adds paths to the name of
{class}`pytask.PythonNode`'s allowing for better hashing.
- {pull}`397` adds support for {class}`typing.NamedTuple` and attrs classes in
`@pytask.mark.task(kwargs=...)`.

## 0.3.2 - 2023-06-07

Expand Down
129 changes: 79 additions & 50 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""This module provides utility functions for :mod:`_pytask.collect`."""
from __future__ import annotations

import inspect
import itertools
import uuid
from pathlib import Path
Expand Down Expand Up @@ -79,9 +78,15 @@ def parse_nodes(
session: Session, path: Path, name: str, obj: Any, parser: Callable[..., Any]
) -> Any:
"""Parse nodes from object."""
arg_name = parser.__name__
objects = _extract_nodes_from_function_markers(obj, parser)
nodes = _convert_objects_to_node_dictionary(objects, parser.__name__)
nodes = tree_map(lambda x: _collect_old_dependencies(session, path, name, x), nodes)
nodes = _convert_objects_to_node_dictionary(objects, arg_name)
nodes = tree_map(
lambda x: _collect_decorator_nodes(
session, path, name, NodeInfo(arg_name, (), x)
),
nodes,
)
return nodes


Expand Down Expand Up @@ -211,27 +216,61 @@ def _merge_dictionaries(list_of_dicts: list[dict[Any, Any]]) -> dict[Any, Any]:
return out


def parse_dependencies_from_task_function(
_ERROR_MULTIPLE_DEPENDENCY_DEFINITIONS = """The task uses multiple ways to define \
dependencies. Dependencies should be defined with either

- '@pytask.mark.depends_on' and a 'depends_on' function argument.
- as default value for the function argument 'depends_on'.

Use only one of the two ways!

Hint: You do not need to use 'depends_on' since pytask v0.4. Every function argument \
that is not a product is treated as a dependency. Read more about dependencies in the \
documentation: https://tinyurl.com/yrezszr4.
"""


def parse_dependencies_from_task_function( # noqa: C901
session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
"""Parse dependencies from task function."""
has_depends_on_decorator = False
has_depends_on_argument = False
dependencies = {}

if has_mark(obj, "depends_on"):
has_depends_on_decorator = True
nodes = parse_nodes(session, path, name, obj, depends_on)
return {"depends_on": nodes}
dependencies["depends_on"] = nodes

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
signature_defaults = parse_keyword_arguments_from_signature_defaults(obj)
kwargs = {**signature_defaults, **task_kwargs}
kwargs.pop("produces", None)

# Parse dependencies from task decorated with @task and that uses depends_on.
if "depends_on" in kwargs:
has_depends_on_argument = True
dependencies["depends_on"] = tree_map(
lambda x: _collect_decorator_nodes(
session, path, name, NodeInfo(arg_name="depends_on", path=(), value=x)
),
kwargs["depends_on"],
)

if has_depends_on_decorator and has_depends_on_argument:
raise NodeNotCollectedError(_ERROR_MULTIPLE_DEPENDENCY_DEFINITIONS)

parameters_with_product_annot = _find_args_with_product_annotation(obj)
parameters_with_node_annot = _find_args_with_node_annotation(obj)

dependencies = {}
for parameter_name, value in kwargs.items():
if parameter_name in parameters_with_product_annot:
continue

if parameter_name == "depends_on":
continue

if parameter_name in parameters_with_node_annot:

def _evolve(x: Any) -> Any:
Expand Down Expand Up @@ -316,18 +355,23 @@ def parse_products_from_task_function(

"""
has_produces_decorator = False
has_task_decorator = False
has_signature_default = False
has_produces_argument = False
has_annotation = False
out = {}

# Parse products from decorators.
if has_mark(obj, "produces"):
has_produces_decorator = True
nodes = parse_nodes(session, path, name, obj, produces)
out = {"produces": nodes}

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
if "produces" in task_kwargs:
signature_defaults = parse_keyword_arguments_from_signature_defaults(obj)
kwargs = {**signature_defaults, **task_kwargs}

# Parse products from task decorated with @task and that uses produces.
if "produces" in kwargs:
has_produces_argument = True
collected_products = tree_map_with_path(
lambda p, x: _collect_product(
session,
Expand All @@ -336,35 +380,15 @@ def parse_products_from_task_function(
NodeInfo(arg_name="produces", path=p, value=x),
is_string_allowed=True,
),
task_kwargs["produces"],
kwargs["produces"],
)
out = {"produces": collected_products}

parameters = inspect.signature(obj).parameters

if not has_mark(obj, "task") and "produces" in parameters:
parameter = parameters["produces"]
if parameter.default is not parameter.empty:
has_signature_default = True
# Use _collect_new_node to not collect strings.
collected_products = tree_map_with_path(
lambda p, x: _collect_product(
session,
path,
name,
NodeInfo(arg_name="produces", path=p, value=x),
is_string_allowed=False,
),
parameter.default,
)
out = {"produces": collected_products}

parameters_with_product_annot = _find_args_with_product_annotation(obj)
if parameters_with_product_annot:
has_annotation = True
for parameter_name in parameters_with_product_annot:
parameter = parameters[parameter_name]
if parameter.default is not parameter.empty:
if parameter_name in kwargs:
# Use _collect_new_node to not collect strings.
collected_products = tree_map_with_path(
lambda p, x: _collect_product(
Expand All @@ -374,19 +398,12 @@ def parse_products_from_task_function(
NodeInfo(parameter_name, p, x), # noqa: B023
is_string_allowed=False,
),
parameter.default,
kwargs[parameter_name],
)
out = {parameter_name: collected_products}

if (
sum(
(
has_produces_decorator,
has_task_decorator,
has_signature_default,
has_annotation,
)
)
sum((has_produces_decorator, has_produces_argument, has_annotation))
>= 2 # noqa: PLR2004
):
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)
Expand All @@ -412,8 +429,15 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
return args_with_product_annot


def _collect_old_dependencies(
session: Session, path: Path, name: str, node: str | Path
_ERROR_WRONG_TYPE_DECORATOR = """'@pytask.mark.depends_on', '@pytask.mark.produces', \
and their function arguments can only accept values of type 'str' and 'pathlib.Path' \
or the same values nested in tuples, lists, and dictionaries. Here, {node} has type \
{node_type}.
"""


def _collect_decorator_nodes(
session: Session, path: Path, name: str, node_info: NodeInfo
) -> dict[str, MetaNode]:
"""Collect nodes for a task.

Expand All @@ -423,22 +447,26 @@ def _collect_old_dependencies(
If the node could not be collected.

"""
node = node_info.value

if not isinstance(node, (str, Path)):
raise ValueError(
"'@pytask.mark.depends_on' and '@pytask.mark.produces' can only accept "
"values of type 'str' and 'pathlib.Path' or the same values nested in "
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
raise NodeNotCollectedError(
_ERROR_WRONG_TYPE_DECORATOR.format(node=node, node_type=type(node))
)

if isinstance(node, str):
node = Path(node)
node_info = node_info._replace(value=node)

collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=NodeInfo("produces", (), node)
session=session, path=path, node_info=node_info
)
if collected_node is None:
kind = {"depends_on": "dependency", "produces": "product"}.get(
node_info.arg_name
)
raise NodeNotCollectedError(
f"{node!r} cannot be parsed as a dependency for task {name!r} in {path!r}."
f"{node!r} cannot be parsed as a {kind} for task {name!r} in {path!r}."
)

return collected_node
Expand All @@ -458,7 +486,7 @@ def _collect_dependencies(
node = node_info.value

collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info, node=node
session=session, path=path, node_info=node_info
)
if collected_node is None:
raise NodeNotCollectedError(
Expand Down Expand Up @@ -505,9 +533,10 @@ def _collect_product(

if isinstance(node, str):
node = Path(node)
node_info = node_info._replace(value=node)

collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info, node=node
session=session, path=path, node_info=node_info
)
if collected_node is None:
raise NodeNotCollectedError(
Expand Down
5 changes: 5 additions & 0 deletions src/_pytask/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
def _database_url_callback(
ctx: Context, name: str, value: str | None # noqa: ARG001
) -> URL:
"""Check the url for the database."""
# Since sqlalchemy v2.0.19, we need to shortcircuit here.
if value is None:
return None

try:
return make_url(value)
except ArgumentError:
Expand Down
38 changes: 22 additions & 16 deletions src/_pytask/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from typing import Callable

import attrs
from _pytask.mark import Mark
from _pytask.models import CollectionMetadata
from _pytask.shared import find_duplicates
Expand Down Expand Up @@ -113,15 +114,6 @@ def parse_collected_tasks_with_task_marker(
else:
collected_tasks[name] = [i[1] for i in parsed_tasks if i[0] == name][0]

# TODO: Remove when parsing dependencies and products from all arguments is
# implemented.
for task in collected_tasks.values():
meta = task.pytask_meta # type: ignore[attr-defined]
for marker_name in ("depends_on", "produces"):
if marker_name in meta.kwargs:
value = meta.kwargs.pop(marker_name)
meta.markers.append(Mark(marker_name, (value,), {}))

return collected_tasks


Expand All @@ -143,24 +135,38 @@ def _parse_tasks_with_preliminary_names(

def _parse_task(task: Callable[..., Any]) -> tuple[str, Callable[..., Any]]:
"""Parse a single task."""
name = task.pytask_meta.name # type: ignore[attr-defined]
if name is None and task.__name__ == "_":
meta = task.pytask_meta # type: ignore[attr-defined]

if meta.name is None and task.__name__ == "_":
raise ValueError(
"A task function either needs 'name' passed by the ``@pytask.mark.task`` "
"decorator or the function name of the task function must not be '_'."
)

parsed_name = task.__name__ if name is None else name
parsed_name = task.__name__ if meta.name is None else meta.name
parsed_kwargs = _parse_task_kwargs(meta.kwargs)

signature_kwargs = parse_keyword_arguments_from_signature_defaults(task)
task.pytask_meta.kwargs = { # type: ignore[attr-defined]
**task.pytask_meta.kwargs, # type: ignore[attr-defined]
**signature_kwargs,
}
meta.kwargs = {**signature_kwargs, **parsed_kwargs}

return parsed_name, task


def _parse_task_kwargs(kwargs: Any) -> dict[str, Any]:
"""Parse task kwargs."""
if isinstance(kwargs, dict):
return kwargs
# Handle namedtuples.
if callable(getattr(kwargs, "_asdict", None)):
return kwargs._asdict()
if attrs.has(type(kwargs)):
return attrs.asdict(kwargs)
raise ValueError(
"'@pytask.mark.task(kwargs=...) needs to be a dictionary, namedtuple or an "
"instance of an attrs class."
)


def parse_keyword_arguments_from_signature_defaults(
task: Callable[..., Any]
) -> dict[str, Any]:
Expand Down
Loading