1
1
import itertools
2
+ from dataclasses import dataclass
3
+
2
4
import numpy as np
3
- from typing import Any , Dict , List , Optional , Tuple , Union
5
+ from typing import Any , Dict , List , Optional , Tuple , Union , Callable
4
6
5
7
import gym
6
- from gym import error , spaces
8
+ from gym import error , spaces , Env
9
+ from stable_baselines3 .common .vec_env import VecEnv , SubprocVecEnv
7
10
8
11
from mlagents_envs .base_env import ActionTuple , BaseEnv
9
12
from mlagents_envs .base_env import DecisionSteps , TerminalSteps
10
13
from mlagents_envs import logging_util
14
+ from mlagents_envs .environment import UnityEnvironment
15
+ from mlagents_envs .side_channel .engine_configuration_channel import (
16
+ EngineConfig ,
17
+ EngineConfigurationChannel ,
18
+ )
11
19
12
20
13
21
class UnityGymException (error .Error ):
@@ -23,6 +31,48 @@ class UnityGymException(error.Error):
23
31
24
32
GymStepResult = Tuple [np .ndarray , float , bool , Dict ]
25
33
34
+ # Default values from CLI (See cli_utils.py)
35
+ DEFAULT_ENGINE_CONFIG = EngineConfig (
36
+ width = 84 ,
37
+ height = 84 ,
38
+ quality_level = 4 ,
39
+ time_scale = 20 ,
40
+ target_frame_rate = - 1 ,
41
+ capture_frame_rate = 60 ,
42
+ )
43
+
44
+
45
+ # Some config subset of an actual config.yaml file for MLA.
46
+ @dataclass
47
+ class LimitedConfig :
48
+ env_path : str
49
+ num_env : int = 1
50
+ engine_config : EngineConfig = DEFAULT_ENGINE_CONFIG
51
+
52
+
53
+ def make_mla_sb3_env (config : LimitedConfig ) -> VecEnv :
54
+ def create_env (path : str , worker_id : int , seed : int ) -> Callable [[], Env ]:
55
+ def _f () -> Env :
56
+ engine_configuration_channel = EngineConfigurationChannel ()
57
+ engine_configuration_channel .set_configuration (config .engine_config )
58
+ side_channels = [engine_configuration_channel ]
59
+ return UnityToGymWrapper (
60
+ UnityEnvironment (
61
+ file_name = path ,
62
+ worker_id = worker_id ,
63
+ seed = seed ,
64
+ side_channels = side_channels ,
65
+ ),
66
+ uint8_visual = True ,
67
+ )
68
+
69
+ return _f
70
+
71
+ env_facts = [
72
+ create_env (config .env_path , worker_id = x , seed = x ) for x in range (config .num_env )
73
+ ]
74
+ return SubprocVecEnv (env_facts )
75
+
26
76
27
77
class UnityToGymWrapper (gym .Env ):
28
78
"""
0 commit comments