Commit 97eb505

Author: Henry Peteet
Restart crashed Unity environments (#5553)
### Change(s)

Update the SubprocessEnvManager to restart workers when the underlying Unity environments crash. When a worker receives an ENV_EXITED signal it will now:

1. Record all failures coming through the step queue and drop all other messages.
2. Purge any pending trajectories, as they may belong to a crashed worker or be corrupted.
3. Restart all failed workers (up to a configurable limit).

This behavior can be limited via a rate limit, a max lifetime limit, or both. The configuration options for both are shown below with their default values.

⚠️ Each of these options applies to a single environment: if `num_envs > 1`, the limit applies separately to each replica (`num_envs = 2` will spawn 2 Unity environments which can each be restarted 10 times).

```yaml
env_settings:
  # Can restart 10 times over the lifetime of the experiment.
  max_lifetime_restarts: 10
  # Rate limit of 1 failure per 60s
  restarts_rate_limit_n: 1
  restarts_rate_limit_period_s: 60
```

They can of course be passed via CLI arguments as well:

```bash
--max-lifetime-restarts
--restarts-rate-limit-n
--restarts-rate-limit-period-s
```

### Disabling this feature

* Rate limiting can be turned off by setting `--restarts-rate-limit-n=-1`.
* Lifetime limiting can be turned off by setting `--max-lifetime-restarts=-1`.
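A minimal sketch of how the two limits combine, using illustrative names (this helper is not part of ml-agents): a worker is only restarted while it is under both the lifetime cap and the rate limit for the current window, and `-1` disables the corresponding check.

```python
# Hypothetical helper, not the manager's API: mirrors the restart policy
# described above. restarts_in_window is the number of restarts recorded
# within the last restarts_rate_limit_period_s seconds.
def can_restart(
    lifetime_restarts: int,
    restarts_in_window: int,
    max_lifetime_restarts: int = 10,
    restarts_rate_limit_n: int = 1,
) -> bool:
    lifetime_ok = (
        max_lifetime_restarts == -1 or lifetime_restarts < max_lifetime_restarts
    )
    rate_ok = (
        restarts_rate_limit_n == -1 or restarts_in_window < restarts_rate_limit_n
    )
    return lifetime_ok and rate_ok


# With the defaults, a second crash inside the same 60 s window is not restarted:
assert can_restart(lifetime_restarts=0, restarts_in_window=0)
assert not can_restart(lifetime_restarts=1, restarts_in_window=1)
# Setting a limit to -1 disables that particular check:
assert can_restart(lifetime_restarts=99, restarts_in_window=0, max_lifetime_restarts=-1)
```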
1 parent 37a6aa2 commit 97eb505

7 files changed: +217 -8 lines changed


com.unity.ml-agents/CHANGELOG.md (+5)

@@ -22,6 +22,11 @@ and this project adheres to
 - Set gym version in gym-unity to gym release 0.20.0
 - Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538)

+- Changed default behavior to restart crashed Unity environments rather than exiting. (#5553)
+  - Rate & lifetime limits on this are configurable via 3 new yaml options
+    1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10]
+    2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
+    3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
 ### Bug Fixes

 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

docs/Training-ML-Agents.md (+3)

@@ -207,6 +207,9 @@ env_settings:
   base_port: 5005
   num_envs: 1
   seed: -1
+  max_lifetime_restarts: 10
+  restarts_rate_limit_n: 1
+  restarts_rate_limit_period_s: 60
 ```

 #### Engine settings

ml-agents-envs/mlagents_envs/rpc_communicator.py (+4 -1)

@@ -58,7 +58,10 @@ def create_server(self):

         try:
             # Establish communication grpc
-            self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
+            self.server = grpc.server(
+                thread_pool=ThreadPoolExecutor(max_workers=10),
+                options=(("grpc.so_reuseport", 1),),
+            )
             self.unity_to_external = UnityToExternalServicerImplementation()
             add_UnityToExternalProtoServicer_to_server(
                 self.unity_to_external, self.server
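The added `grpc.so_reuseport` option allows another server to bind the same port while an earlier one is still being torn down, which is presumably what lets a restarted worker reclaim its port. A minimal standalone sketch of a server created this way (illustrative only, not the project's `create_server`; the port-check logic is omitted):

```python
# Illustrative sketch, not mlagents_envs' RpcCommunicator: create a gRPC server
# with SO_REUSEPORT enabled so a replacement server can bind the same port.
from concurrent.futures import ThreadPoolExecutor

import grpc


def make_server(port: int) -> grpc.Server:
    server = grpc.server(
        thread_pool=ThreadPoolExecutor(max_workers=10),
        options=(("grpc.so_reuseport", 1),),
    )
    # add_insecure_port returns the bound port, or 0 if binding failed.
    if server.add_insecure_port(f"localhost:{port}") == 0:
        raise RuntimeError(f"Could not bind port {port}")
    server.start()
    return server
```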

ml-agents/mlagents/trainers/cli_utils.py (+20)

@@ -177,6 +177,26 @@ def _create_parser() -> argparse.ArgumentParser:
         "passed to the executable.",
         action=DetectDefault,
     )
+    argparser.add_argument(
+        "--max-lifetime-restarts",
+        default=10,
+        help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
+        "Can be set to -1 if no limit is desired.",
+        action=DetectDefault,
+    )
+    argparser.add_argument(
+        "--restarts-rate-limit-n",
+        default=1,
+        help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
+        "restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
+        action=DetectDefault,
+    )
+    argparser.add_argument(
+        "--restarts-rate-limit-period-s",
+        default=60,
+        help="The period of time --restarts-rate-limit-n applies to.",
+        action=DetectDefault,
+    )
     argparser.add_argument(
         "--torch",
         default=False,
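For illustration only, a stripped-down parser with the same three flags (the real parser also attaches the `DetectDefault` action and lives inside `_create_parser`, both omitted here) shows how values are supplied and how `-1` turns a limit off:

```python
# Standalone illustration, not the ml-agents CLI itself.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-lifetime-restarts", type=int, default=10)
parser.add_argument("--restarts-rate-limit-n", type=int, default=1)
parser.add_argument("--restarts-rate-limit-period-s", type=int, default=60)

args = parser.parse_args(
    ["--max-lifetime-restarts", "-1", "--restarts-rate-limit-period-s", "120"]
)
assert args.max_lifetime_restarts == -1  # lifetime limit disabled
assert args.restarts_rate_limit_period_s == 120  # 2-minute rate-limit window
```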

ml-agents/mlagents/trainers/settings.py (+5)

@@ -823,6 +823,11 @@ class EnvironmentSettings:
     base_port: int = parser.get_default("base_port")
     num_envs: int = attr.ib(default=parser.get_default("num_envs"))
     seed: int = parser.get_default("seed")
+    max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts")
+    restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n")
+    restarts_rate_limit_period_s: int = parser.get_default(
+        "restarts_rate_limit_period_s"
+    )

     @num_envs.validator
     def validate_num_envs(self, attribute, value):

ml-agents/mlagents/trainers/subprocess_env_manager.py (+117 -5)

@@ -1,3 +1,4 @@
+import datetime
 from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
 import cloudpickle
 import enum
@@ -251,6 +252,14 @@ def __init__(
         self.env_workers: List[UnityEnvWorker] = []
         self.step_queue: Queue = Queue()
         self.workers_alive = 0
+        self.env_factory = env_factory
+        self.run_options = run_options
+        self.env_parameters: Optional[Dict] = None
+        # Each worker is correlated with a list of times they restarted within the last time period.
+        self.recent_restart_timestamps: List[List[datetime.datetime]] = [
+            [] for _ in range(n_env)
+        ]
+        self.restart_counts: List[int] = [0] * n_env
         for worker_idx in range(n_env):
             self.env_workers.append(
                 self.create_worker(
@@ -293,6 +302,105 @@ def _queue_steps(self) -> None:
                 env_worker.send(EnvironmentCommand.STEP, env_action_info)
                 env_worker.waiting = True

+    def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
+        if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
+            return
+        # Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
+        # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
+        other_failures: Dict[int, Exception] = self._drain_step_queue()
+        # TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
+        failures: Dict[int, Exception] = {
+            **{first_failure.worker_id: first_failure.payload},
+            **other_failures,
+        }
+        for worker_id, ex in failures.items():
+            self._assert_worker_can_restart(worker_id, ex)
+            logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
+            self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
+            self.restart_counts[worker_id] += 1
+            self.env_workers[worker_id] = self.create_worker(
+                worker_id, self.step_queue, self.env_factory, self.run_options
+            )
+        # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
+        # outdated data.
+        self.reset(self.env_parameters)
+
+    def _drain_step_queue(self) -> Dict[int, Exception]:
+        """
+        Drains all steps out of the step queue and returns all exceptions from crashed workers.
+        This will effectively pause all workers so that they won't do anything until _queue_steps is called.
+        """
+        all_failures = {}
+        workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
+        deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
+        while workers_still_pending and deadline > datetime.datetime.now():
+            try:
+                while True:
+                    step: EnvironmentResponse = self.step_queue.get_nowait()
+                    if step.cmd == EnvironmentCommand.ENV_EXITED:
+                        workers_still_pending.add(step.worker_id)
+                        all_failures[step.worker_id] = step.payload
+                    else:
+                        workers_still_pending.remove(step.worker_id)
+                        self.env_workers[step.worker_id].waiting = False
+            except EmptyQueueException:
+                pass
+        if deadline < datetime.datetime.now():
+            still_waiting = {w.worker_id for w in self.env_workers if w.waiting}
+            raise TimeoutError(f"Workers {still_waiting} stuck in waiting state")
+        return all_failures
+
+    def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
+        """
+        Checks if we can recover from an exception from a worker.
+        If the restart limit is exceeded it will raise a UnityCommunicationException.
+        If the exception is not recoverable it re-raises the exception.
+        """
+        if (
+            isinstance(exception, UnityCommunicationException)
+            or isinstance(exception, UnityTimeOutException)
+            or isinstance(exception, UnityEnvironmentException)
+            or isinstance(exception, UnityCommunicatorStoppedException)
+        ):
+            if self._worker_has_restart_quota(worker_id):
+                return
+            else:
+                logger.error(
+                    f"Worker {worker_id} exceeded the allowed number of restarts."
+                )
+                raise exception
+        raise exception
+
+    def _worker_has_restart_quota(self, worker_id: int) -> bool:
+        self._drop_old_restart_timestamps(worker_id)
+        max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
+        max_limit_check = (
+            max_lifetime_restarts == -1
+            or self.restart_counts[worker_id] < max_lifetime_restarts
+        )
+
+        rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
+        rate_limit_check = (
+            rate_limit_n == -1
+            or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
+        )
+
+        return rate_limit_check and max_limit_check
+
+    def _drop_old_restart_timestamps(self, worker_id: int) -> None:
+        """
+        Drops environment restart timestamps that are outside of the current window.
+        """
+
+        def _filter(t: datetime.datetime) -> bool:
+            return t > datetime.datetime.now() - datetime.timedelta(
+                seconds=self.run_options.env_settings.restarts_rate_limit_period_s
+            )
+
+        self.recent_restart_timestamps[worker_id] = list(
+            filter(_filter, self.recent_restart_timestamps[worker_id])
+        )
+
     def _step(self) -> List[EnvironmentStep]:
         # Queue steps for any workers which aren't in the "waiting" state.
         self._queue_steps()
@@ -306,15 +414,18 @@ def _step(self) -> List[EnvironmentStep]:
                 while True:
                     step: EnvironmentResponse = self.step_queue.get_nowait()
                     if step.cmd == EnvironmentCommand.ENV_EXITED:
-                        env_exception: Exception = step.payload
-                        raise env_exception
-                    self.env_workers[step.worker_id].waiting = False
-                    if step.worker_id not in step_workers:
+                        # If even one env exits try to restart all envs that failed.
+                        self._restart_failed_workers(step)
+                        # Clear state and restart this function.
+                        worker_steps.clear()
+                        step_workers.clear()
+                        self._queue_steps()
+                    elif step.worker_id not in step_workers:
+                        self.env_workers[step.worker_id].waiting = False
                         worker_steps.append(step)
                         step_workers.add(step.worker_id)
         except EmptyQueueException:
             pass
-
         step_infos = self._postprocess_steps(worker_steps)
         return step_infos

@@ -339,6 +450,7 @@ def set_env_parameters(self, config: Dict = None) -> None:
         EnvironmentParametersSidehannel for each worker.
         :param config: Dict of environment parameter keys and values
         """
+        self.env_parameters = config
         for ew in self.env_workers:
             ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)
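The rate limit rests on a sliding window of restart timestamps. Below is a self-contained sketch of that bookkeeping with illustrative names rather than the manager's own methods; it mirrors what `_drop_old_restart_timestamps` and `_worker_has_restart_quota` do for a single worker.

```python
# Illustrative only, not SubprocessEnvManager: the sliding-window check behind
# restarts_rate_limit_n / restarts_rate_limit_period_s.
import datetime
from typing import List


def prune_window(
    restart_times: List[datetime.datetime], period_s: int
) -> List[datetime.datetime]:
    """Keep only restarts that happened within the last period_s seconds."""
    cutoff = datetime.datetime.now() - datetime.timedelta(seconds=period_s)
    return [t for t in restart_times if t > cutoff]


def within_rate_limit(
    restart_times: List[datetime.datetime], rate_limit_n: int, period_s: int
) -> bool:
    """True if another restart is allowed; rate_limit_n == -1 disables the check."""
    if rate_limit_n == -1:
        return True
    return len(prune_window(restart_times, period_s)) < rate_limit_n


# Example: a restart 90 seconds ago no longer counts against a 60 second window.
history = [datetime.datetime.now() - datetime.timedelta(seconds=90)]
assert within_rate_limit(history, rate_limit_n=1, period_s=60)
```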

ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (+63 -2)

@@ -1,5 +1,5 @@
 from unittest import mock
-from unittest.mock import Mock, MagicMock
+from unittest.mock import Mock, MagicMock, call, ANY
 import unittest
 import pytest
 from queue import Empty as EmptyQueue
@@ -14,7 +14,10 @@
 from mlagents.trainers.env_manager import EnvironmentStep
 from mlagents_envs.base_env import BaseEnv
 from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
-from mlagents_envs.exception import UnityEnvironmentException
+from mlagents_envs.exception import (
+    UnityEnvironmentException,
+    UnityCommunicationException,
+)
 from mlagents.trainers.tests.simple_test_envs import (
     SimpleEnvironment,
     UnexpectedExceptionEnvironment,
@@ -153,6 +156,64 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
             manager.env_workers[1].previous_step,
         ]

+    @mock.patch(
+        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
+    )
+    def test_crashed_env_restarts(self, mock_create_worker):
+        crashing_worker = MockEnvWorker(
+            0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
+        )
+        restarting_worker = MockEnvWorker(
+            0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
+        )
+        healthy_worker = MockEnvWorker(
+            1, EnvironmentResponse(EnvironmentCommand.RESET, 1, 1)
+        )
+        mock_create_worker.side_effect = [
+            crashing_worker,
+            healthy_worker,
+            restarting_worker,
+        ]
+        manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)
+        manager.step_queue = Mock()
+        manager.step_queue.get_nowait.side_effect = [
+            EnvironmentResponse(
+                EnvironmentCommand.ENV_EXITED,
+                0,
+                UnityCommunicationException("Test msg"),
+            ),
+            EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
+            EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
+            EmptyQueue(),
+            EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
+            EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
+            EmptyQueue(),
+        ]
+        step_mock = Mock()
+        last_steps = [Mock(), Mock(), Mock()]
+        assert crashing_worker is manager.env_workers[0]
+        assert healthy_worker is manager.env_workers[1]
+        crashing_worker.previous_step = last_steps[0]
+        crashing_worker.waiting = True
+        healthy_worker.previous_step = last_steps[1]
+        healthy_worker.waiting = True
+        manager._take_step = Mock(return_value=step_mock)
+        manager._step()
+        healthy_worker.send.assert_has_calls(
+            [
+                call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
+                call(EnvironmentCommand.RESET, ANY),
+                call(EnvironmentCommand.STEP, ANY),
+            ]
+        )
+        restarting_worker.send.assert_has_calls(
+            [
+                call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
+                call(EnvironmentCommand.RESET, ANY),
+                call(EnvironmentCommand.STEP, ANY),
+            ]
+        )
+
     @mock.patch("mlagents.trainers.subprocess_env_manager.SubprocessEnvManager._step")
     @mock.patch(
         "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.training_behaviors",
