Skip to content

Commit 9e03966

Browse files
author
Chris Elion
authored
[MLA-1952] Add optional seed for gym action spaces (#5303) (#5315)
1 parent b629b49 commit 9e03966

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

com.unity.ml-agents/CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
66
and this project adheres to
77
[Semantic Versioning](http://semver.org/spec/v2.0.0.html).
88

9-
109
## [2.0.0-exp.1] - 2021-04-22
1110
### Major Changes
1211
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
@@ -77,6 +76,9 @@ or actuators on your system. (#5194)
7776
#### ml-agents / ml-agents-envs / gym-unity (Python)
7877
- Fixed a bug where --results-dir has no effect. (#5269)
7978
- Fixed a bug where old `.pt` checkpoints were not deleted during training. (#5271)
79+
- The `UnityToGymWrapper` initializer now accepts an optional `action_space_seed` seed. If this is specified, it will
80+
be used to set the random seed on the resulting action space. (#5303)
81+
8082

8183
## [1.9.1-preview] - 2021-04-13
8284
### Major Changes

gym-unity/gym_unity/envs/__init__.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import itertools
22
import numpy as np
3-
from typing import Any, Dict, List, Tuple, Union
3+
from typing import Any, Dict, List, Optional, Tuple, Union
44

55
import gym
66
from gym import error, spaces
@@ -35,6 +35,7 @@ def __init__(
3535
uint8_visual: bool = False,
3636
flatten_branched: bool = False,
3737
allow_multiple_obs: bool = False,
38+
action_space_seed: Optional[int] = None,
3839
):
3940
"""
4041
Environment initialization
@@ -46,6 +47,7 @@ def __init__(
4647
containing the visual observations and the last element containing the array of vector observations.
4748
If False, returns a single np.ndarray containing either only a single visual observation or the array of
4849
vector observations.
50+
:param action_space_seed: If non-None, will be used to set the random seed on created gym.Space instances.
4951
"""
5052
self._env = unity_env
5153

@@ -130,6 +132,9 @@ def __init__(
130132
"and continuous actions."
131133
)
132134

135+
if action_space_seed is not None:
136+
self._action_space.seed(action_space_seed)
137+
133138
# Set observations space
134139
list_spaces: List[gym.Space] = []
135140
shapes = self._get_vis_obs_shape()
@@ -305,7 +310,7 @@ def reward_range(self) -> Tuple[float, float]:
305310
return -float("inf"), float("inf")
306311

307312
@property
308-
def action_space(self):
313+
def action_space(self) -> gym.Space:
309314
return self._action_space
310315

311316
@property

gym-unity/gym_unity/tests/test_gym.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def test_gym_wrapper():
2222
mock_env, mock_spec, mock_decision_step, mock_terminal_step
2323
)
2424
env = UnityToGymWrapper(mock_env)
25-
assert isinstance(env, UnityToGymWrapper)
2625
assert isinstance(env.reset(), np.ndarray)
2726
actions = env.action_space.sample()
2827
assert actions.shape[0] == 2
@@ -78,6 +77,21 @@ def test_action_space():
7877
assert env.action_space.n == 5
7978

8079

80+
def test_action_space_seed():
81+
mock_env = mock.MagicMock()
82+
mock_spec = create_mock_group_spec()
83+
mock_decision_step, mock_terminal_step = create_mock_vector_steps(mock_spec)
84+
setup_mock_unityenvironment(
85+
mock_env, mock_spec, mock_decision_step, mock_terminal_step
86+
)
87+
actions = []
88+
for _ in range(0, 2):
89+
env = UnityToGymWrapper(mock_env, action_space_seed=1337)
90+
env.reset()
91+
actions.append(env.action_space.sample())
92+
assert (actions[0] == actions[1]).all()
93+
94+
8195
@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
8296
def test_gym_wrapper_visual(use_uint8):
8397
mock_env = mock.MagicMock()
@@ -93,7 +107,6 @@ def test_gym_wrapper_visual(use_uint8):
93107

94108
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8)
95109
assert isinstance(env.observation_space, spaces.Box)
96-
assert isinstance(env, UnityToGymWrapper)
97110
assert isinstance(env.reset(), np.ndarray)
98111
actions = env.action_space.sample()
99112
assert actions.shape[0] == 2
@@ -121,7 +134,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):
121134
)
122135

123136
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
124-
assert isinstance(env, UnityToGymWrapper)
125137
assert isinstance(env.observation_space, spaces.Tuple)
126138
assert len(env.observation_space) == 2
127139
reset_obs = env.reset()
@@ -143,7 +155,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):
143155

144156
# check behavior for allow_multiple_obs = False
145157
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
146-
assert isinstance(env, UnityToGymWrapper)
147158
assert isinstance(env.observation_space, spaces.Box)
148159
reset_obs = env.reset()
149160
assert isinstance(reset_obs, np.ndarray)
@@ -170,7 +181,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):
170181
)
171182

172183
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
173-
assert isinstance(env, UnityToGymWrapper)
174184
assert isinstance(env.observation_space, spaces.Tuple)
175185
assert len(env.observation_space) == 3
176186
reset_obs = env.reset()
@@ -188,7 +198,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):
188198

189199
# check behavior for allow_multiple_obs = False
190200
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
191-
assert isinstance(env, UnityToGymWrapper)
192201
assert isinstance(env.observation_space, spaces.Box)
193202
reset_obs = env.reset()
194203
assert isinstance(reset_obs, np.ndarray)

0 commit comments

Comments
 (0)