Skip to content

Commit 9beab9a

Browse files
Improved error messages for P2P shuffling (dask#7979)
Co-authored-by: Lawrence Mitchell <[email protected]>
1 parent 8a5ae5f commit 9beab9a

File tree

3 files changed

38 additions and 25 deletions (+38 / -25 lines changed)

3 files changed

+38
-25
lines changed

distributed/shuffle/_scheduler_extension.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,9 @@ class ShuffleState(abc.ABC):
4545
def to_msg(self) -> dict[str, Any]:
4646
"""Transform the shuffle state into a JSON-serializable message"""
4747

48+
def __str__(self) -> str:
49+
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
50+
4851

4952
@dataclass
5053
class DataFrameShuffleState(ShuffleState):
@@ -119,15 +122,24 @@ def shuffle_ids(self) -> set[ShuffleId]:
119122

120123
async def barrier(self, id: ShuffleId, run_id: int) -> None:
121124
shuffle = self.states[id]
125+
assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}"
122126
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
123127
await self.scheduler.broadcast(
124128
msg=msg, workers=list(shuffle.participating_workers)
125129
)
126130

127131
def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
128132
shuffle = self.states[id]
129-
if shuffle.run_id != run_id:
130-
return {"status": "error", "message": "Stale shuffle"}
133+
if shuffle.run_id > run_id:
134+
return {
135+
"status": "error",
136+
"message": f"Request stale, expected {run_id=} for {shuffle}",
137+
}
138+
elif shuffle.run_id < run_id:
139+
return {
140+
"status": "error",
141+
"message": f"Request invalid, expected {run_id=} for {shuffle}",
142+
}
131143
ts = self.scheduler.tasks[key]
132144
self._set_restriction(ts, worker)
133145
return {"status": "OK"}
@@ -298,9 +310,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
298310
for shuffle_id, shuffle in self.states.items():
299311
if worker not in shuffle.participating_workers:
300312
continue
301-
exception = RuntimeError(
302-
f"Worker {worker} left during active shuffle {shuffle_id}"
303-
)
313+
exception = RuntimeError(f"Worker {worker} left during active {shuffle}")
304314
self.erred_shuffles[shuffle_id] = exception
305315
self._fail_on_workers(shuffle, str(exception))
306316

@@ -335,7 +345,7 @@ def transition(
335345
shuffle = self.states[shuffle_id]
336346
except KeyError:
337347
return
338-
self._fail_on_workers(shuffle, message=f"Shuffle {shuffle_id} forgotten")
348+
self._fail_on_workers(shuffle, message=f"{shuffle} forgotten")
339349
self._clean_on_scheduler(shuffle_id)
340350

341351
def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None:

distributed/shuffle/_worker_extension.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __init__(
9999
self._closed_event = asyncio.Event()
100100

101101
def __repr__(self) -> str:
102-
return f"<{self.__class__.__name__} {self.id}[{self.run_id}] on {self.local_address}>"
102+
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
103+
104+
def __str__(self) -> str:
105+
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
103106

104107
def __hash__(self) -> int:
105108
return self.run_id
@@ -162,9 +165,7 @@ def raise_if_closed(self) -> None:
162165
if self.closed:
163166
if self._exception:
164167
raise self._exception
165-
raise ShuffleClosedError(
166-
f"Shuffle {self.id} has been closed on {self.local_address}"
167-
)
168+
raise ShuffleClosedError(f"{self} has already been closed")
168169

169170
async def inputs_done(self) -> None:
170171
self.raise_if_closed()
@@ -346,7 +347,7 @@ async def _receive(self, data: list[tuple[ArrayRechunkShardID, bytes]]) -> None:
346347
async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int:
347348
self.raise_if_closed()
348349
if self.transferred:
349-
raise RuntimeError(f"Cannot add more partitions to shuffle {self}")
350+
raise RuntimeError(f"Cannot add more partitions to {self}")
350351

351352
def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]:
352353
"""Return a mapping of worker addresses to a list of tuples of shard IDs
@@ -511,7 +512,7 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, list[bytes]]:
511512
async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int:
512513
self.raise_if_closed()
513514
if self.transferred:
514-
raise RuntimeError(f"Cannot add more partitions to shuffle {self}")
515+
raise RuntimeError(f"Cannot add more partitions to {self}")
515516

516517
def _() -> dict[str, list[tuple[int, bytes]]]:
517518
out = split_by_worker(
@@ -586,6 +587,12 @@ def __init__(self, worker: Worker) -> None:
586587
self.closed = False
587588
self._executor = ThreadPoolExecutor(self.worker.state.nthreads)
588589

590+
def __str__(self) -> str:
591+
return f"ShuffleWorkerExtension on {self.worker.address}"
592+
593+
def __repr__(self) -> str:
594+
return f"<ShuffleWorkerExtension, worker={self.worker.address_safe!r}, closed={self.closed}>"
595+
589596
# Handlers
590597
##########
591598
# NOTE: handlers are not threadsafe, but they're called from async comms, so that's okay
@@ -695,11 +702,11 @@ async def _get_shuffle_run(
695702
shuffle = await self._refresh_shuffle(
696703
shuffle_id=shuffle_id,
697704
)
698-
if run_id < shuffle.run_id:
699-
raise RuntimeError("Stale shuffle")
700-
elif run_id > shuffle.run_id:
701-
# This should never happen
702-
raise RuntimeError("Invalid shuffle state")
705+
706+
if shuffle.run_id > run_id:
707+
raise RuntimeError(f"{run_id=} stale, got {shuffle}")
708+
elif shuffle.run_id < run_id:
709+
raise RuntimeError(f"{run_id=} invalid, got {shuffle}")
703710

704711
if shuffle._exception:
705712
raise shuffle._exception
@@ -729,9 +736,7 @@ async def _get_or_create_shuffle(
729736
)
730737

731738
if self.closed:
732-
raise ShuffleClosedError(
733-
f"{self.__class__.__name__} already closed on {self.worker.address}"
734-
)
739+
raise ShuffleClosedError(f"{self} has already been closed")
735740
if shuffle._exception:
736741
raise shuffle._exception
737742
return shuffle
@@ -790,9 +795,7 @@ async def _refresh_shuffle(
790795
assert result["status"] == "OK"
791796

792797
if self.closed:
793-
raise ShuffleClosedError(
794-
f"{self.__class__.__name__} already closed on {self.worker.address}"
795-
)
798+
raise ShuffleClosedError(f"{self} has already been closed")
796799
if shuffle_id in self.shuffles:
797800
existing = self.shuffles[shuffle_id]
798801
if existing.run_id >= result["run_id"]:

distributed/shuffle/tests/test_shuffle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,7 +1601,7 @@ async def test_shuffle_run_consistency(c, s, a):
16011601

16021602
# This should never occur, but fetching an ID larger than the ID available on
16031603
# the scheduler should result in an error.
1604-
with pytest.raises(RuntimeError, match="Invalid shuffle state"):
1604+
with pytest.raises(RuntimeError, match="invalid"):
16051605
await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"] + 1)
16061606

16071607
# Finish first execution
@@ -1628,7 +1628,7 @@ async def test_shuffle_run_consistency(c, s, a):
16281628
assert await worker_ext._get_shuffle_run(shuffle_id, new_shuffle_dict["run_id"])
16291629

16301630
# Fetching a stale run from a worker aware of the new run raises an error
1631-
with pytest.raises(RuntimeError, match="Stale shuffle"):
1631+
with pytest.raises(RuntimeError, match="stale"):
16321632
await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"])
16331633

16341634
worker_ext.block_barrier.set()

0 commit comments

Comments
 (0)