
Commit a867e13

Henry Peteet committed:
Basic example running with multiple envs in sb3
1 parent 71b121c commit a867e13

File tree

2 files changed (+103 -2)


gym-unity/gym_unity/envs/__init__.py (+52 -2)
@@ -1,13 +1,21 @@
 import itertools
+from dataclasses import dataclass
+
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
 
 import gym
-from gym import error, spaces
+from gym import error, spaces, Env
+from stable_baselines3.common.vec_env import VecEnv, SubprocVecEnv
 
 from mlagents_envs.base_env import ActionTuple, BaseEnv
 from mlagents_envs.base_env import DecisionSteps, TerminalSteps
 from mlagents_envs import logging_util
+from mlagents_envs.environment import UnityEnvironment
+from mlagents_envs.side_channel.engine_configuration_channel import (
+    EngineConfig,
+    EngineConfigurationChannel,
+)
 
 
 class UnityGymException(error.Error):
@@ -23,6 +31,48 @@ class UnityGymException(error.Error):
 
 GymStepResult = Tuple[np.ndarray, float, bool, Dict]
 
+# Default values from the CLI (see cli_utils.py)
+DEFAULT_ENGINE_CONFIG = EngineConfig(
+    width=84,
+    height=84,
+    quality_level=4,
+    time_scale=20,
+    target_frame_rate=-1,
+    capture_frame_rate=60,
+)
+
+
+# A limited subset of the options in a full ML-Agents config.yaml file.
+@dataclass
+class LimitedConfig:
+    env_path: str
+    num_env: int = 1
+    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG
+
+
+def make_mla_sb3_env(config: LimitedConfig) -> VecEnv:
+    def create_env(path: str, worker_id: int, seed: int) -> Callable[[], Env]:
+        def _f() -> Env:
+            engine_configuration_channel = EngineConfigurationChannel()
+            engine_configuration_channel.set_configuration(config.engine_config)
+            side_channels = [engine_configuration_channel]
+            return UnityToGymWrapper(
+                UnityEnvironment(
+                    file_name=path,
+                    worker_id=worker_id,
+                    seed=seed,
+                    side_channels=side_channels,
+                ),
+                uint8_visual=True,
+            )
+
+        return _f
+
+    env_facts = [
+        create_env(config.env_path, worker_id=x, seed=x) for x in range(config.num_env)
+    ]
+    return SubprocVecEnv(env_facts)
+
 
 class UnityToGymWrapper(gym.Env):
     """

sb3_examples/3dball_num_envs.py (+51)
@@ -0,0 +1,51 @@
+from math import ceil
+
+from baselines.common.schedules import LinearSchedule
+from stable_baselines3 import PPO
+from stable_baselines3.common.vec_env import VecMonitor
+
+from gym_unity.envs import make_mla_sb3_env, LimitedConfig
+
+TOTAL_TRAINING_STEPS_GOAL = (
+    500000
+)  # Same as the config for CI 3DBall... Not sure if MLA steps == SB3 steps.
+NUM_ENVS = 12
+STEPS_PER_UPDATE = 2048
+
+
+# NOTE: This only achieves ~90/100 reward and is just a POC. Needs tuning to be useful.
+def main():
+    env = make_mla_sb3_env(
+        LimitedConfig(
+            env_path="/Users/henry.peteet/Documents/RandomBuilds/3DBallSingleNoVis",
+            num_env=NUM_ENVS,
+        )
+    )
+    # Log results in the "results" folder
+    env = VecMonitor(env, "results")
+    # Attempt to approximate settings from 3DBall.yaml
+    schedule = LinearSchedule(
+        schedule_timesteps=TOTAL_TRAINING_STEPS_GOAL, final_p=0.0, initial_p=0.0003
+    )
+    model = PPO(
+        "MlpPolicy",
+        env,
+        verbose=1,
+        # TODO: Check if I am using the schedule correctly.
+        learning_rate=lambda progress: schedule.value(
+            TOTAL_TRAINING_STEPS_GOAL * progress
+        ),
+        tensorboard_log="results",
+        n_steps=int(STEPS_PER_UPDATE),
+    )
+    training_rounds = ceil(TOTAL_TRAINING_STEPS_GOAL / int(STEPS_PER_UPDATE * NUM_ENVS))
+    for i in range(training_rounds):
+        print(f"Training round {i + 1}/{training_rounds}")
+        # NOTE: reset_num_timesteps should only be True the first time so that tensorboard logs stay consistent.
+        model.learn(total_timesteps=6000, reset_num_timesteps=(i == 0))
+    model.policy.eval()
+    env.close()
+
+
+if __name__ == "__main__":
+    main()
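
On the TODO above: Stable-Baselines3 calls a learning-rate schedule with progress_remaining, which starts at 1.0 and decreases linearly to 0.0 over training. Because the baselines LinearSchedule interpolates from initial_p at t=0 toward final_p at t=schedule_timesteps, passing TOTAL_TRAINING_STEPS_GOAL * progress evaluates near final_p (0.0) at the start of training. Assuming the intent was a decay from 3e-4 to 0, a dependency-free sketch could look like:

    # Assumed replacement for the baselines LinearSchedule; not part of the commit.
    def linear_schedule(initial_lr: float):
        def schedule(progress_remaining: float) -> float:
            # progress_remaining is 1.0 at the start of training and 0.0 at the end.
            return initial_lr * progress_remaining
        return schedule

    # Usage: PPO("MlpPolicy", env, learning_rate=linear_schedule(0.0003), ...)

This would also drop the old OpenAI baselines dependency, which is otherwise imported only for the schedule.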
