1
+ import datetime
1
2
from typing import Dict , NamedTuple , List , Any , Optional , Callable , Set
2
3
import cloudpickle
3
4
import enum
@@ -251,6 +252,14 @@ def __init__(
251
252
self .env_workers : List [UnityEnvWorker ] = []
252
253
self .step_queue : Queue = Queue ()
253
254
self .workers_alive = 0
255
+ self .env_factory = env_factory
256
+ self .run_options = run_options
257
+ self .env_parameters : Optional [Dict ] = None
258
+ # Each worker is correlated with a list of times they restarted within the last time period.
259
+ self .recent_restart_timestamps : List [List [datetime .datetime ]] = [
260
+ [] for _ in range (n_env )
261
+ ]
262
+ self .restart_counts : List [int ] = [0 ] * n_env
254
263
for worker_idx in range (n_env ):
255
264
self .env_workers .append (
256
265
self .create_worker (
@@ -293,6 +302,105 @@ def _queue_steps(self) -> None:
293
302
env_worker .send (EnvironmentCommand .STEP , env_action_info )
294
303
env_worker .waiting = True
295
304
305
+ def _restart_failed_workers (self , first_failure : EnvironmentResponse ) -> None :
306
+ if first_failure .cmd != EnvironmentCommand .ENV_EXITED :
307
+ return
308
+ # Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
309
+ # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
310
+ other_failures : Dict [int , Exception ] = self ._drain_step_queue ()
311
+ # TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
312
+ failures : Dict [int , Exception ] = {
313
+ ** {first_failure .worker_id : first_failure .payload },
314
+ ** other_failures ,
315
+ }
316
+ for worker_id , ex in failures .items ():
317
+ self ._assert_worker_can_restart (worker_id , ex )
318
+ logger .warning (f"Restarting worker[{ worker_id } ] after '{ ex } '" )
319
+ self .recent_restart_timestamps [worker_id ].append (datetime .datetime .now ())
320
+ self .restart_counts [worker_id ] += 1
321
+ self .env_workers [worker_id ] = self .create_worker (
322
+ worker_id , self .step_queue , self .env_factory , self .run_options
323
+ )
324
+ # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
325
+ # outdated data.
326
+ self .reset (self .env_parameters )
327
+
328
+ def _drain_step_queue (self ) -> Dict [int , Exception ]:
329
+ """
330
+ Drains all steps out of the step queue and returns all exceptions from crashed workers.
331
+ This will effectively pause all workers so that they won't do anything until _queue_steps is called.
332
+ """
333
+ all_failures = {}
334
+ workers_still_pending = {w .worker_id for w in self .env_workers if w .waiting }
335
+ deadline = datetime .datetime .now () + datetime .timedelta (minutes = 1 )
336
+ while workers_still_pending and deadline > datetime .datetime .now ():
337
+ try :
338
+ while True :
339
+ step : EnvironmentResponse = self .step_queue .get_nowait ()
340
+ if step .cmd == EnvironmentCommand .ENV_EXITED :
341
+ workers_still_pending .add (step .worker_id )
342
+ all_failures [step .worker_id ] = step .payload
343
+ else :
344
+ workers_still_pending .remove (step .worker_id )
345
+ self .env_workers [step .worker_id ].waiting = False
346
+ except EmptyQueueException :
347
+ pass
348
+ if deadline < datetime .datetime .now ():
349
+ still_waiting = {w .worker_id for w in self .env_workers if w .waiting }
350
+ raise TimeoutError (f"Workers { still_waiting } stuck in waiting state" )
351
+ return all_failures
352
+
353
+ def _assert_worker_can_restart (self , worker_id : int , exception : Exception ) -> None :
354
+ """
355
+ Checks if we can recover from an exception from a worker.
356
+ If the restart limit is exceeded it will raise a UnityCommunicationException.
357
+ If the exception is not recoverable it re-raises the exception.
358
+ """
359
+ if (
360
+ isinstance (exception , UnityCommunicationException )
361
+ or isinstance (exception , UnityTimeOutException )
362
+ or isinstance (exception , UnityEnvironmentException )
363
+ or isinstance (exception , UnityCommunicatorStoppedException )
364
+ ):
365
+ if self ._worker_has_restart_quota (worker_id ):
366
+ return
367
+ else :
368
+ logger .error (
369
+ f"Worker { worker_id } exceeded the allowed number of restarts."
370
+ )
371
+ raise exception
372
+ raise exception
373
+
374
+ def _worker_has_restart_quota (self , worker_id : int ) -> bool :
375
+ self ._drop_old_restart_timestamps (worker_id )
376
+ max_lifetime_restarts = self .run_options .env_settings .max_lifetime_restarts
377
+ max_limit_check = (
378
+ max_lifetime_restarts == - 1
379
+ or self .restart_counts [worker_id ] < max_lifetime_restarts
380
+ )
381
+
382
+ rate_limit_n = self .run_options .env_settings .restarts_rate_limit_n
383
+ rate_limit_check = (
384
+ rate_limit_n == - 1
385
+ or len (self .recent_restart_timestamps [worker_id ]) < rate_limit_n
386
+ )
387
+
388
+ return rate_limit_check and max_limit_check
389
+
390
+ def _drop_old_restart_timestamps (self , worker_id : int ) -> None :
391
+ """
392
+ Drops environment restart timestamps that are outside of the current window.
393
+ """
394
+
395
+ def _filter (t : datetime .datetime ) -> bool :
396
+ return t > datetime .datetime .now () - datetime .timedelta (
397
+ seconds = self .run_options .env_settings .restarts_rate_limit_period_s
398
+ )
399
+
400
+ self .recent_restart_timestamps [worker_id ] = list (
401
+ filter (_filter , self .recent_restart_timestamps [worker_id ])
402
+ )
403
+
296
404
def _step (self ) -> List [EnvironmentStep ]:
297
405
# Queue steps for any workers which aren't in the "waiting" state.
298
406
self ._queue_steps ()
@@ -306,15 +414,18 @@ def _step(self) -> List[EnvironmentStep]:
306
414
while True :
307
415
step : EnvironmentResponse = self .step_queue .get_nowait ()
308
416
if step .cmd == EnvironmentCommand .ENV_EXITED :
309
- env_exception : Exception = step .payload
310
- raise env_exception
311
- self .env_workers [step .worker_id ].waiting = False
312
- if step .worker_id not in step_workers :
417
+ # If even one env exits try to restart all envs that failed.
418
+ self ._restart_failed_workers (step )
419
+ # Clear state and restart this function.
420
+ worker_steps .clear ()
421
+ step_workers .clear ()
422
+ self ._queue_steps ()
423
+ elif step .worker_id not in step_workers :
424
+ self .env_workers [step .worker_id ].waiting = False
313
425
worker_steps .append (step )
314
426
step_workers .add (step .worker_id )
315
427
except EmptyQueueException :
316
428
pass
317
-
318
429
step_infos = self ._postprocess_steps (worker_steps )
319
430
return step_infos
320
431
@@ -339,6 +450,7 @@ def set_env_parameters(self, config: Dict = None) -> None:
339
450
EnvironmentParametersSidehannel for each worker.
340
451
:param config: Dict of environment parameter keys and values
341
452
"""
453
+ self .env_parameters = config
342
454
for ew in self .env_workers :
343
455
ew .send (EnvironmentCommand .ENVIRONMENT_PARAMETERS , config )
344
456
0 commit comments