|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import functools
|
| 4 | +import sys |
4 | 5 | from itertools import repeat
|
5 | 6 | from textwrap import dedent
|
6 | 7 | from typing import TYPE_CHECKING, Callable, Tuple
|
@@ -202,10 +203,15 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
|
202 | 203 | ],
|
203 | 204 | )
|
204 | 205 | )
|
| 206 | + func_with_error_context = _handle_errors_with_path_context( |
| 207 | + node_of_first_tree.path |
| 208 | + )(func) |
205 | 209 |
|
206 | 210 | # Now we can call func on the data in this particular set of corresponding nodes
|
207 | 211 | results = (
|
208 |
| - func(*node_args_as_datasets, **node_kwargs_as_datasets) |
| 212 | + func_with_error_context( |
| 213 | + *node_args_as_datasets, **node_kwargs_as_datasets |
| 214 | + ) |
209 | 215 | if node_of_first_tree.has_data
|
210 | 216 | else None
|
211 | 217 | )
|
@@ -251,6 +257,34 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
|
251 | 257 | return _map_over_subtree
|
252 | 258 |
|
253 | 259 |
|
| 260 | +def _handle_errors_with_path_context(path): |
| 261 | + """Wraps given function so that if it fails it also raises path to node on which it failed.""" |
| 262 | + |
| 263 | + def decorator(func): |
| 264 | + def wrapper(*args, **kwargs): |
| 265 | + try: |
| 266 | + return func(*args, **kwargs) |
| 267 | + except Exception as e: |
| 268 | + if sys.version_info >= (3, 11): |
| 269 | + # Add the context information to the error message |
| 270 | + e.add_note( |
| 271 | + f"Raised whilst mapping function over node with path {path}" |
| 272 | + ) |
| 273 | + raise |
| 274 | + |
| 275 | + return wrapper |
| 276 | + |
| 277 | + return decorator |
| 278 | + |
| 279 | + |
| 280 | +def add_note(err: BaseException, msg: str) -> None: |
| 281 | + # TODO: remove once python 3.10 can be dropped |
| 282 | + if sys.version_info < (3, 11): |
| 283 | + err.__notes__ = getattr(err, "__notes__", []) + [msg] |
| 284 | + else: |
| 285 | + err.add_note(msg) |
| 286 | + |
| 287 | + |
254 | 288 | def _check_single_set_return_values(path_to_node, obj):
|
255 | 289 | """Check types returned from single evaluation of func, and return number of return values received from func."""
|
256 | 290 | if isinstance(obj, (Dataset, DataArray)):
|
|
0 commit comments