diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 091b736dcb..ebdfcc5484 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -22,6 +22,11 @@ and this project adheres to
 - Set gym version in gym-unity to gym release 0.20.0
 - Added support for having `beta`, `epsilon`, and `learning rate` on separate
   schedules (affects only PPO and POCA). (#5538)
+- Changed default behavior to restart crashed Unity environments rather than exiting. (#5553)
+  - Rate & lifetime limits on this are configurable via 3 new yaml options
+    1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10]
+    2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
+    3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]

 ### Bug Fixes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 427bfe6051..f62402e766 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -207,6 +207,9 @@ env_settings:
   base_port: 5005
   num_envs: 1
   seed: -1
+  max_lifetime_restarts: 10
+  restarts_rate_limit_n: 1
+  restarts_rate_limit_period_s: 60
 ```

 #### Engine settings
diff --git a/ml-agents-envs/mlagents_envs/rpc_communicator.py b/ml-agents-envs/mlagents_envs/rpc_communicator.py
index e1f38a8260..13c0df08b1 100644
--- a/ml-agents-envs/mlagents_envs/rpc_communicator.py
+++ b/ml-agents-envs/mlagents_envs/rpc_communicator.py
@@ -58,7 +58,10 @@ def create_server(self):
         try:
             # Establish communication grpc
-            self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
+            self.server = grpc.server(
+                thread_pool=ThreadPoolExecutor(max_workers=10),
+                options=(("grpc.so_reuseport", 1),),
+            )
             self.unity_to_external = UnityToExternalServicerImplementation()
             add_UnityToExternalProtoServicer_to_server(
                 self.unity_to_external, self.server
             )
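(Illustration, not part of the patch.) The `grpc.so_reuseport` option added above presumably lets a replacement gRPC server bind the same worker port shortly after a crashed environment is torn down, instead of failing with an "address already in use" error. Below is a minimal standalone sketch of the same server setup; the port number and the absence of a registered servicer are illustrative assumptions, not ml-agents code:

```python
# Sketch only: a gRPC server created the same way as in rpc_communicator.py,
# with SO_REUSEPORT enabled so a replacement server can bind the same port
# soon after a crash. Port and servicer registration are placeholders.
from concurrent.futures import ThreadPoolExecutor

import grpc


def create_reusable_server(port: int = 5005) -> grpc.Server:
    server = grpc.server(
        thread_pool=ThreadPoolExecutor(max_workers=10),
        options=(("grpc.so_reuseport", 1),),
    )
    # In ml-agents, the UnityToExternalProto servicer would be registered here.
    server.add_insecure_port(f"localhost:{port}")
    server.start()
    return server


if __name__ == "__main__":
    server = create_reusable_server()
    server.stop(None)  # shut down immediately; this is only a demonstration
```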
diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py
index 2cd36eb1f9..5f87ea094c 100644
--- a/ml-agents/mlagents/trainers/cli_utils.py
+++ b/ml-agents/mlagents/trainers/cli_utils.py
@@ -177,6 +177,26 @@ def _create_parser() -> argparse.ArgumentParser:
         "passed to the executable.",
         action=DetectDefault,
     )
+    argparser.add_argument(
+        "--max-lifetime-restarts",
+        default=10,
+        help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
+        "Can be set to -1 if no limit is desired.",
+        action=DetectDefault,
+    )
+    argparser.add_argument(
+        "--restarts-rate-limit-n",
+        default=1,
+        help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
+        "restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
+        action=DetectDefault,
+    )
+    argparser.add_argument(
+        "--restarts-rate-limit-period-s",
+        default=60,
+        help="The period of time --restarts-rate-limit-n applies to.",
+        action=DetectDefault,
+    )
     argparser.add_argument(
         "--torch",
         default=False,
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 6b8fd8bfec..fbdf317636 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -823,6 +823,11 @@ class EnvironmentSettings:
     base_port: int = parser.get_default("base_port")
     num_envs: int = attr.ib(default=parser.get_default("num_envs"))
     seed: int = parser.get_default("seed")
+    max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts")
+    restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n")
+    restarts_rate_limit_period_s: int = parser.get_default(
+        "restarts_rate_limit_period_s"
+    )

     @num_envs.validator
     def validate_num_envs(self, attribute, value):
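(Illustration, not part of the patch.) A minimal sketch of setting the three new fields programmatically, assuming `EnvironmentSettings` remains an attrs class that accepts keyword arguments; the values are arbitrary and only demonstrate the `-1` "disable" sentinel described in the CLI help above:

```python
# Sketch only: constructing EnvironmentSettings directly with the new fields.
# -1 disables the corresponding limit, per the help text for
# --max-lifetime-restarts and --restarts-rate-limit-n.
from mlagents.trainers.settings import EnvironmentSettings

env_settings = EnvironmentSettings(
    max_lifetime_restarts=-1,          # no lifetime cap on restarts
    restarts_rate_limit_n=3,           # ...but at most 3 restarts
    restarts_rate_limit_period_s=300,  # per 5-minute sliding window
)
print(env_settings.max_lifetime_restarts, env_settings.restarts_rate_limit_n)
```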
diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py
index e96703f62c..eac64546a0 100644
--- a/ml-agents/mlagents/trainers/subprocess_env_manager.py
+++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py
@@ -1,3 +1,4 @@
+import datetime
 from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
 import cloudpickle
 import enum
@@ -251,6 +252,14 @@ def __init__(
         self.env_workers: List[UnityEnvWorker] = []
         self.step_queue: Queue = Queue()
         self.workers_alive = 0
+        self.env_factory = env_factory
+        self.run_options = run_options
+        self.env_parameters: Optional[Dict] = None
+        # Each worker is correlated with a list of times they restarted within the last time period.
+        self.recent_restart_timestamps: List[List[datetime.datetime]] = [
+            [] for _ in range(n_env)
+        ]
+        self.restart_counts: List[int] = [0] * n_env
         for worker_idx in range(n_env):
             self.env_workers.append(
                 self.create_worker(
@@ -293,6 +302,105 @@ def _queue_steps(self) -> None:
             env_worker.send(EnvironmentCommand.STEP, env_action_info)
             env_worker.waiting = True

+    def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
+        if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
+            return
+        # Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
+        # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
+        other_failures: Dict[int, Exception] = self._drain_step_queue()
+        # TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
+        failures: Dict[int, Exception] = {
+            **{first_failure.worker_id: first_failure.payload},
+            **other_failures,
+        }
+        for worker_id, ex in failures.items():
+            self._assert_worker_can_restart(worker_id, ex)
+            logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
+            self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
+            self.restart_counts[worker_id] += 1
+            self.env_workers[worker_id] = self.create_worker(
+                worker_id, self.step_queue, self.env_factory, self.run_options
+            )
+        # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
+        # outdated data.
+        self.reset(self.env_parameters)
+
+    def _drain_step_queue(self) -> Dict[int, Exception]:
+        """
+        Drains all steps out of the step queue and returns all exceptions from crashed workers.
+        This will effectively pause all workers so that they won't do anything until _queue_steps is called.
+        """
+        all_failures = {}
+        workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
+        deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
+        while workers_still_pending and deadline > datetime.datetime.now():
+            try:
+                while True:
+                    step: EnvironmentResponse = self.step_queue.get_nowait()
+                    if step.cmd == EnvironmentCommand.ENV_EXITED:
+                        workers_still_pending.add(step.worker_id)
+                        all_failures[step.worker_id] = step.payload
+                    else:
+                        workers_still_pending.remove(step.worker_id)
+                        self.env_workers[step.worker_id].waiting = False
+            except EmptyQueueException:
+                pass
+        if deadline < datetime.datetime.now():
+            still_waiting = {w.worker_id for w in self.env_workers if w.waiting}
+            raise TimeoutError(f"Workers {still_waiting} stuck in waiting state")
+        return all_failures
+
+    def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
+        """
+        Checks if we can recover from an exception from a worker.
+        If the restart limit is exceeded it will raise a UnityCommunicationException.
+        If the exception is not recoverable it re-raises the exception.
+        """
+        if (
+            isinstance(exception, UnityCommunicationException)
+            or isinstance(exception, UnityTimeOutException)
+            or isinstance(exception, UnityEnvironmentException)
+            or isinstance(exception, UnityCommunicatorStoppedException)
+        ):
+            if self._worker_has_restart_quota(worker_id):
+                return
+            else:
+                logger.error(
+                    f"Worker {worker_id} exceeded the allowed number of restarts."
+                )
+                raise exception
+        raise exception
+
+    def _worker_has_restart_quota(self, worker_id: int) -> bool:
+        self._drop_old_restart_timestamps(worker_id)
+        max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
+        max_limit_check = (
+            max_lifetime_restarts == -1
+            or self.restart_counts[worker_id] < max_lifetime_restarts
+        )
+
+        rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
+        rate_limit_check = (
+            rate_limit_n == -1
+            or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
+        )
+
+        return rate_limit_check and max_limit_check
+
+    def _drop_old_restart_timestamps(self, worker_id: int) -> None:
+        """
+        Drops environment restart timestamps that are outside of the current window.
+        """
+
+        def _filter(t: datetime.datetime) -> bool:
+            return t > datetime.datetime.now() - datetime.timedelta(
+                seconds=self.run_options.env_settings.restarts_rate_limit_period_s
+            )
+
+        self.recent_restart_timestamps[worker_id] = list(
+            filter(_filter, self.recent_restart_timestamps[worker_id])
+        )
+
     def _step(self) -> List[EnvironmentStep]:
         # Queue steps for any workers which aren't in the "waiting" state.
         self._queue_steps()
@@ -306,15 +414,18 @@ def _step(self) -> List[EnvironmentStep]:
                 while True:
                     step: EnvironmentResponse = self.step_queue.get_nowait()
                     if step.cmd == EnvironmentCommand.ENV_EXITED:
-                        env_exception: Exception = step.payload
-                        raise env_exception
-                    self.env_workers[step.worker_id].waiting = False
-                    if step.worker_id not in step_workers:
+                        # If even one env exits try to restart all envs that failed.
+                        self._restart_failed_workers(step)
+                        # Clear state and restart this function.
+                        worker_steps.clear()
+                        step_workers.clear()
+                        self._queue_steps()
+                    elif step.worker_id not in step_workers:
+                        self.env_workers[step.worker_id].waiting = False
                         worker_steps.append(step)
                         step_workers.add(step.worker_id)
             except EmptyQueueException:
                 pass
-
         step_infos = self._postprocess_steps(worker_steps)
         return step_infos

@@ -339,6 +450,7 @@ def set_env_parameters(self, config: Dict = None) -> None:
         EnvironmentParametersSidehannel for each worker.
         :param config: Dict of environment parameter keys and values
         """
+        self.env_parameters = config
        for ew in self.env_workers:
            ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)
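(Illustration, not part of the patch.) The restart policy implemented by `_worker_has_restart_quota` and `_drop_old_restart_timestamps` above can be read as a lifetime budget plus a sliding-window rate limit. Here is a self-contained sketch of the same check, with illustrative names that are not part of the ml-agents API:

```python
# Standalone, illustrative version of the restart-quota logic above.
# A worker may restart only if it is under its lifetime budget and has not
# already restarted rate_limit_n times inside the sliding window.
import datetime
from typing import List


def has_restart_quota(
    lifetime_restart_count: int,
    recent_restarts: List[datetime.datetime],
    max_lifetime_restarts: int = 10,
    rate_limit_n: int = 1,
    rate_limit_period_s: int = 60,
) -> bool:
    window_start = datetime.datetime.now() - datetime.timedelta(
        seconds=rate_limit_period_s
    )
    # Restarts older than the window no longer count toward the rate limit.
    restarts_in_window = [t for t in recent_restarts if t > window_start]
    lifetime_ok = (
        max_lifetime_restarts == -1
        or lifetime_restart_count < max_lifetime_restarts
    )
    rate_ok = rate_limit_n == -1 or len(restarts_in_window) < rate_limit_n
    return lifetime_ok and rate_ok
```

With the defaults (10, 1, 60), a worker may restart at most once in any 60-second window and at most ten times over the lifetime of the run.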
diff --git a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
index 4fd9cc96c2..553b0c11b3 100644
--- a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
+++ b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
@@ -1,5 +1,5 @@
 from unittest import mock
-from unittest.mock import Mock, MagicMock
+from unittest.mock import Mock, MagicMock, call, ANY
 import unittest
 import pytest
 from queue import Empty as EmptyQueue
@@ -14,7 +14,10 @@
 from mlagents.trainers.env_manager import EnvironmentStep
 from mlagents_envs.base_env import BaseEnv
 from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
-from mlagents_envs.exception import UnityEnvironmentException
+from mlagents_envs.exception import (
+    UnityEnvironmentException,
+    UnityCommunicationException,
+)
 from mlagents.trainers.tests.simple_test_envs import (
     SimpleEnvironment,
     UnexpectedExceptionEnvironment,
@@ -153,6 +156,64 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
             manager.env_workers[1].previous_step,
         ]

+    @mock.patch(
+        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
+    )
+    def test_crashed_env_restarts(self, mock_create_worker):
+        crashing_worker = MockEnvWorker(
+            0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
+        )
+        restarting_worker = MockEnvWorker(
+            0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
+        )
+        healthy_worker = MockEnvWorker(
+            1, EnvironmentResponse(EnvironmentCommand.RESET, 1, 1)
+        )
+        mock_create_worker.side_effect = [
+            crashing_worker,
+            healthy_worker,
+            restarting_worker,
+        ]
+        manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)
+        manager.step_queue = Mock()
+        manager.step_queue.get_nowait.side_effect = [
+            EnvironmentResponse(
+                EnvironmentCommand.ENV_EXITED,
+                0,
+                UnityCommunicationException("Test msg"),
+            ),
+            EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
+            EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
+            EmptyQueue(),
+            EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
+            EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
+            EmptyQueue(),
+        ]
+        step_mock = Mock()
+        last_steps = [Mock(), Mock(), Mock()]
+        assert crashing_worker is manager.env_workers[0]
+        assert healthy_worker is manager.env_workers[1]
+        crashing_worker.previous_step = last_steps[0]
+        crashing_worker.waiting = True
+        healthy_worker.previous_step = last_steps[1]
+        healthy_worker.waiting = True
+        manager._take_step = Mock(return_value=step_mock)
+        manager._step()
+        healthy_worker.send.assert_has_calls(
+            [
+                call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
+                call(EnvironmentCommand.RESET, ANY),
+                call(EnvironmentCommand.STEP, ANY),
+            ]
+        )
+        restarting_worker.send.assert_has_calls(
+            [
+                call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
+                call(EnvironmentCommand.RESET, ANY),
+                call(EnvironmentCommand.STEP, ANY),
+            ]
+        )
+
     @mock.patch("mlagents.trainers.subprocess_env_manager.SubprocessEnvManager._step")
     @mock.patch(
         "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.training_behaviors",